Skip to main content

oximo_expr/
classify.rs

1use crate::arena::{ExprArena, ExprId, ExprNode};
2
/// Coarsest translation path a backend needs for an expression: linear,
/// quadratic, or general nonlinear.
///
/// The class is read from the expression's structure, without constant
/// folding, so e.g. a cubic term multiplied by a literal zero still counts as
/// nonlinear.
///
/// Variants are declared from least to most general, so the derived `Ord`
/// makes `max` return the dominating class. For example, a model with a
/// quadratic objective and a nonlinear constraint is `Nonlinear`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum ExprClass {
    /// Polynomial degree at most one: constants, parameters and affine terms.
    Linear,
    /// Polynomial of degree exactly two.
    Quadratic,
    /// Anything beyond quadratic, including non-polynomial terms.
    Nonlinear,
}
16
/// Polynomial-degree bucket used internally by [`classify`].
///
/// `Higher` is a saturating sentinel meaning "above quadratic". Both
/// polynomial degree > 2 and non-polynomial nodes (transcendentals, division,
/// `abs`) land in it, because neither fits a QP solver's quadratic API.
///
/// Variants are declared in increasing order, so the derived `Ord` follows
/// degree.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Degree {
    Zero,
    One,
    Two,
    Higher,
}

impl Degree {
    /// Numeric rank of the bucket. `Higher` maps to 3, the first degree that
    /// no longer fits a quadratic.
    fn rank(self) -> u32 {
        match self {
            Degree::Zero => 0,
            Degree::One => 1,
            Degree::Two => 2,
            Degree::Higher => 3,
        }
    }

    /// Inverse of [`Degree::rank`], saturating every rank >= 3 to `Higher`.
    fn from_rank(rank: u32) -> Degree {
        match rank {
            0 => Degree::Zero,
            1 => Degree::One,
            2 => Degree::Two,
            _ => Degree::Higher,
        }
    }

    /// Degree of a sum: the larger of the two operand degrees.
    fn add(self, other: Degree) -> Degree {
        Degree::from_rank(self.rank().max(other.rank()))
    }

    /// Degree of a product: operand degrees add, saturating at `Higher`.
    fn mul(self, other: Degree) -> Degree {
        // Ranks are at most 3 each, so the sum cannot overflow.
        Degree::from_rank(self.rank() + other.rank())
    }

    /// Degree of `self` raised to the `n`-th power: the degree scales by `n`,
    /// saturating at `Higher`. Any base to the zeroth power is a constant.
    fn pow(self, n: u32) -> Degree {
        if n == 0 {
            return Degree::Zero;
        }
        // `saturating_mul` keeps huge exponents from overflowing; anything
        // >= 3 collapses to `Higher` in `from_rank` anyway.
        Degree::from_rank(self.rank().saturating_mul(n))
    }
}
54
55fn degree(arena: &ExprArena, id: ExprId) -> Degree {
56    match arena.get(id) {
57        ExprNode::Const(_) | ExprNode::Param(_) => Degree::Zero,
58        ExprNode::Var(_) | ExprNode::Linear { .. } => Degree::One,
59        ExprNode::Neg(inner) => degree(arena, *inner),
60        ExprNode::Add(children) => {
61            let mut d = Degree::Zero;
62            for c in children {
63                d = d.add(degree(arena, *c));
64                if d == Degree::Higher {
65                    return d;
66                }
67            }
68            d
69        }
70        ExprNode::Mul(children) => {
71            let mut d = Degree::Zero;
72            for c in children {
73                d = d.mul(degree(arena, *c));
74                if d == Degree::Higher {
75                    return d;
76                }
77            }
78            d
79        }
80        ExprNode::Pow(base, exp) => {
81            let ExprNode::Const(e) = arena.get(*exp) else { return Degree::Higher };
82            if (*e - e.round()).abs() >= f64::EPSILON || *e < 0.0 {
83                return Degree::Higher;
84            }
85            // Bucket the exponent into the only values `Degree::pow` treats
86            // distinctly.
87            let n = match e.round() {
88                v if v < 0.5 => 0,
89                v if v < 1.5 => 1,
90                v if v < 2.5 => 2,
91                _ => 3,
92            };
93            degree(arena, *base).pow(n)
94        }
95        // Transcendentals are always > quadratic. Division is too: `div_into`
96        // folds the only degree-preserving case (constant denominator) before a
97        // `Div` node is created, so any other `Div` has a non-constant
98        // denominator.
99        ExprNode::Div(_, _)
100        | ExprNode::Sin(_)
101        | ExprNode::Cos(_)
102        | ExprNode::Exp(_)
103        | ExprNode::Log(_)
104        | ExprNode::Abs(_) => Degree::Higher,
105    }
106}
107
108/// Classify an expression as Linear, Quadratic (polynomial degree <= 2 with at
109/// least one degree-2 term), or Nonlinear (transcendentals, non-integer powers,
110/// or polynomial degree > 2).
111pub fn classify(arena: &ExprArena, id: ExprId) -> ExprClass {
112    match degree(arena, id) {
113        Degree::Zero | Degree::One => ExprClass::Linear,
114        Degree::Two => ExprClass::Quadratic,
115        Degree::Higher => ExprClass::Nonlinear,
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{ExprArena, ExprNode, VarId};
    use smallvec::smallvec;

    /// Push a `Var` node for variable index `i`.
    fn var(arena: &mut ExprArena, i: u32) -> ExprId {
        arena.push(ExprNode::Var(VarId(i)))
    }

    /// Push `base ^ exponent` where the exponent is a literal `Const` node.
    fn pow_const(arena: &mut ExprArena, base: ExprId, exponent: f64) -> ExprId {
        let e = arena.push(ExprNode::Const(exponent));
        arena.push(ExprNode::Pow(base, e))
    }

    #[test]
    fn linear_var_sum() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let sum = a.push(ExprNode::Add(smallvec![x, y]));
        assert_eq!(classify(&a, sum), ExprClass::Linear);
    }

    #[test]
    fn quadratic_mul_two_vars() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        assert_eq!(classify(&a, xy), ExprClass::Quadratic);
    }

    #[test]
    fn quadratic_pow_two() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let sq = pow_const(&mut a, x, 2.0);
        assert_eq!(classify(&a, sq), ExprClass::Quadratic);
    }

    #[test]
    fn nonlinear_pow_three() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let cube = pow_const(&mut a, x, 3.0);
        assert_eq!(classify(&a, cube), ExprClass::Nonlinear);
    }

    #[test]
    fn pow_zero_is_linear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let unit = pow_const(&mut a, x, 0.0);
        assert_eq!(classify(&a, unit), ExprClass::Linear);
    }

    #[test]
    fn pow_one_is_linear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let same = pow_const(&mut a, x, 1.0);
        assert_eq!(classify(&a, same), ExprClass::Linear);
    }

    #[test]
    fn fractional_pow_is_nonlinear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let root = pow_const(&mut a, x, 0.5);
        assert_eq!(classify(&a, root), ExprClass::Nonlinear);
    }

    #[test]
    fn negative_pow_is_nonlinear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let recip = pow_const(&mut a, x, -1.0);
        assert_eq!(classify(&a, recip), ExprClass::Nonlinear);
    }

    #[test]
    fn param_exponent_is_nonlinear() {
        let mut a = ExprArena::new();
        let p = a.new_param(2.0);
        let pn = a.param(p);
        let x = var(&mut a, 0);
        let pw = a.push(ExprNode::Pow(x, pn));
        assert_eq!(classify(&a, pw), ExprClass::Nonlinear);
    }

    #[test]
    fn square_of_square_is_nonlinear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let sq = pow_const(&mut a, x, 2.0);
        let quartic = pow_const(&mut a, sq, 2.0);
        assert_eq!(classify(&a, quartic), ExprClass::Nonlinear);
    }

    #[test]
    fn const_base_pow_times_var_is_linear() {
        let mut a = ExprArena::new();
        let c = a.push(ExprNode::Const(3.0));
        let c4 = pow_const(&mut a, c, 4.0);
        let x = var(&mut a, 0);
        let m = a.push(ExprNode::Mul(smallvec![c4, x]));
        assert_eq!(classify(&a, m), ExprClass::Linear);
    }

    #[test]
    fn neg_preserves_degree() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let neg = a.push(ExprNode::Neg(xy));
        assert_eq!(classify(&a, neg), ExprClass::Quadratic);
    }

    #[test]
    fn sum_takes_max_degree() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let c = a.push(ExprNode::Const(1.0));
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let sum = a.push(ExprNode::Add(smallvec![x, xy, c]));
        assert_eq!(classify(&a, sum), ExprClass::Quadratic);
    }

    #[test]
    fn sum_with_transcendental_is_nonlinear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let s = a.push(ExprNode::Sin(x));
        let sum = a.push(ExprNode::Add(smallvec![x, s]));
        assert_eq!(classify(&a, sum), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_div() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let q = a.push(ExprNode::Div(x, y));
        assert_eq!(classify(&a, q), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_sin() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let s = a.push(ExprNode::Sin(x));
        assert_eq!(classify(&a, s), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_abs() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let s = a.push(ExprNode::Abs(x));
        assert_eq!(classify(&a, s), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_triple_mul() {
        let mut arena = ExprArena::new();
        let x = var(&mut arena, 0);
        let y = var(&mut arena, 1);
        let z = var(&mut arena, 2);
        let prod = arena.push(ExprNode::Mul(smallvec![x, y, z]));
        assert_eq!(classify(&arena, prod), ExprClass::Nonlinear);
    }

    #[test]
    fn const_times_var_is_linear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let c = a.push(ExprNode::Const(3.0));
        let m = a.push(ExprNode::Mul(smallvec![c, x]));
        assert_eq!(classify(&a, m), ExprClass::Linear);
    }

    #[test]
    fn param_alone_is_linear() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        assert_eq!(classify(&a, pn), ExprClass::Linear);
    }

    #[test]
    fn param_times_var_is_linear() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        let x = var(&mut a, 0);
        let m = a.push(ExprNode::Mul(smallvec![pn, x]));
        assert_eq!(classify(&a, m), ExprClass::Linear);
    }

    #[test]
    fn param_times_var_squared_is_quadratic() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        let x = var(&mut a, 0);
        let sq = pow_const(&mut a, x, 2.0);
        let m = a.push(ExprNode::Mul(smallvec![pn, sq]));
        assert_eq!(classify(&a, m), ExprClass::Quadratic);
    }

    #[test]
    fn class_max_picks_dominating_class() {
        assert_eq!(
            ExprClass::Quadratic.max(ExprClass::Nonlinear),
            ExprClass::Nonlinear
        );
        assert_eq!(
            ExprClass::Linear.max(ExprClass::Quadratic),
            ExprClass::Quadratic
        );
    }
}