1use crate::arena::{ExprArena, ExprId, ExprNode};
2
/// Coarse structural class of an expression, as reported by [`classify`].
///
/// Variants are declared from most to least restrictive, so the derived
/// `Ord` ranks `Linear < Quadratic < Nonlinear`; the `max` of two classes is
/// therefore the less restrictive one.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ExprClass {
    /// Polynomial degree at most one in the variables (constants included).
    Linear,
    /// Polynomial degree exactly two in the variables.
    Quadratic,
    /// Higher-degree polynomials and non-polynomial operators.
    Nonlinear,
}
16
/// Polynomial degree of an expression in the decision variables,
/// saturating at [`Degree::Higher`].
///
/// The variants are declared in increasing order, so the derived `Ord`
/// ranks `Zero < One < Two < Higher`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Degree {
    /// Independent of the variables (constants and parameters).
    Zero,
    /// Degree one.
    One,
    /// Degree two.
    Two,
    /// Degree three or more, or not a polynomial at all.
    Higher,
}

impl Degree {
    /// Degree of a sum: the larger of the two operand degrees.
    fn add(self, other: Degree) -> Degree {
        Degree::from_rank(self.rank().max(other.rank()))
    }

    /// Degree of a product: operand degrees add, saturating at `Higher`.
    fn mul(self, other: Degree) -> Degree {
        // Ranks are at most 3, so this sum cannot overflow.
        Degree::from_rank(self.rank() + other.rank())
    }

    /// Degree of `self` raised to the `n`-th power, saturating at `Higher`.
    ///
    /// `n == 0` always yields `Zero`, even for `Higher`: anything to the
    /// zeroth power is a constant.
    fn pow(self, n: u32) -> Degree {
        // `saturating_mul` keeps huge exponents from overflowing; any rank
        // of 3 or more maps to `Higher` anyway.
        Degree::from_rank(self.rank().saturating_mul(n))
    }

    /// Position on the scale `0, 1, 2, 3`, where 3 stands for `Higher`.
    fn rank(self) -> u32 {
        match self {
            Degree::Zero => 0,
            Degree::One => 1,
            Degree::Two => 2,
            Degree::Higher => 3,
        }
    }

    /// Inverse of [`Degree::rank`]; any rank of 3 or more saturates to `Higher`.
    fn from_rank(rank: u32) -> Degree {
        match rank {
            0 => Degree::Zero,
            1 => Degree::One,
            2 => Degree::Two,
            _ => Degree::Higher,
        }
    }
}
54
55fn degree(arena: &ExprArena, id: ExprId) -> Degree {
56 match arena.get(id) {
57 ExprNode::Const(_) | ExprNode::Param(_) => Degree::Zero,
58 ExprNode::Var(_) | ExprNode::Linear { .. } => Degree::One,
59 ExprNode::Neg(inner) => degree(arena, *inner),
60 ExprNode::Add(children) => {
61 let mut d = Degree::Zero;
62 for c in children {
63 d = d.add(degree(arena, *c));
64 if d == Degree::Higher {
65 return d;
66 }
67 }
68 d
69 }
70 ExprNode::Mul(children) => {
71 let mut d = Degree::Zero;
72 for c in children {
73 d = d.mul(degree(arena, *c));
74 if d == Degree::Higher {
75 return d;
76 }
77 }
78 d
79 }
80 ExprNode::Pow(base, exp) => {
81 let ExprNode::Const(e) = arena.get(*exp) else { return Degree::Higher };
82 if (*e - e.round()).abs() >= f64::EPSILON || *e < 0.0 {
83 return Degree::Higher;
84 }
85 let n = match e.round() {
88 v if v < 0.5 => 0,
89 v if v < 1.5 => 1,
90 v if v < 2.5 => 2,
91 _ => 3,
92 };
93 degree(arena, *base).pow(n)
94 }
95 ExprNode::Div(_, _)
100 | ExprNode::Sin(_)
101 | ExprNode::Cos(_)
102 | ExprNode::Exp(_)
103 | ExprNode::Log(_)
104 | ExprNode::Abs(_) => Degree::Higher,
105 }
106}
107
108pub fn classify(arena: &ExprArena, id: ExprId) -> ExprClass {
112 match degree(arena, id) {
113 Degree::Zero | Degree::One => ExprClass::Linear,
114 Degree::Two => ExprClass::Quadratic,
115 Degree::Higher => ExprClass::Nonlinear,
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{ExprArena, ExprNode, VarId};
    use smallvec::smallvec;

    fn var(arena: &mut ExprArena, i: u32) -> ExprId {
        arena.push(ExprNode::Var(VarId(i)))
    }

    /// Pushes a numeric constant node.
    fn constant(arena: &mut ExprArena, value: f64) -> ExprId {
        arena.push(ExprNode::Const(value))
    }

    /// Pushes `base ^ exponent` where the exponent is a constant node.
    fn pow_const(arena: &mut ExprArena, base: ExprId, exponent: f64) -> ExprId {
        let e = constant(arena, exponent);
        arena.push(ExprNode::Pow(base, e))
    }

    #[test]
    fn linear_var_sum() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let sum = a.push(ExprNode::Add(smallvec![x, y]));
        assert_eq!(classify(&a, sum), ExprClass::Linear);
    }

    #[test]
    fn quadratic_mul_two_vars() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        assert_eq!(classify(&a, xy), ExprClass::Quadratic);
    }

    #[test]
    fn quadratic_pow_two() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let two = a.push(ExprNode::Const(2.0));
        let sq = a.push(ExprNode::Pow(x, two));
        assert_eq!(classify(&a, sq), ExprClass::Quadratic);
    }

    #[test]
    fn nonlinear_pow_three() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let three = a.push(ExprNode::Const(3.0));
        let cube = a.push(ExprNode::Pow(x, three));
        assert_eq!(classify(&a, cube), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_div() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let q = a.push(ExprNode::Div(x, y));
        assert_eq!(classify(&a, q), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_sin() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let s = a.push(ExprNode::Sin(x));
        assert_eq!(classify(&a, s), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_abs() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let s = a.push(ExprNode::Abs(x));
        assert_eq!(classify(&a, s), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_triple_mul() {
        let mut arena = ExprArena::new();
        let x = var(&mut arena, 0);
        let y = var(&mut arena, 1);
        let z = var(&mut arena, 2);
        let prod = arena.push(ExprNode::Mul(smallvec![x, y, z]));
        assert_eq!(classify(&arena, prod), ExprClass::Nonlinear);
    }

    #[test]
    fn linear_promoted_by_const_mul() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let c = a.push(ExprNode::Const(3.0));
        let m = a.push(ExprNode::Mul(smallvec![c, x]));
        assert_eq!(classify(&a, m), ExprClass::Linear);
    }

    #[test]
    fn param_alone_is_linear() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        assert_eq!(classify(&a, pn), ExprClass::Linear);
    }

    #[test]
    fn param_times_var_is_linear() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        let x = var(&mut a, 0);
        let m = a.push(ExprNode::Mul(smallvec![pn, x]));
        assert_eq!(classify(&a, m), ExprClass::Linear);
    }

    #[test]
    fn param_times_var_squared_is_quadratic() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        let x = var(&mut a, 0);
        let two = a.push(ExprNode::Const(2.0));
        let sq = a.push(ExprNode::Pow(x, two));
        let m = a.push(ExprNode::Mul(smallvec![pn, sq]));
        assert_eq!(classify(&a, m), ExprClass::Quadratic);
    }

    // ---- Pow edge cases ----

    #[test]
    fn linear_pow_one() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let p = pow_const(&mut a, x, 1.0);
        assert_eq!(classify(&a, p), ExprClass::Linear);
    }

    #[test]
    fn pow_zero_is_constant() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let p = pow_const(&mut a, x, 0.0);
        assert_eq!(classify(&a, p), ExprClass::Linear);
    }

    #[test]
    fn nonlinear_fractional_pow() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let p = pow_const(&mut a, x, 0.5);
        assert_eq!(classify(&a, p), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_negative_pow() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let inv = pow_const(&mut a, x, -1.0);
        let inv_sq = pow_const(&mut a, x, -2.0);
        assert_eq!(classify(&a, inv), ExprClass::Nonlinear);
        assert_eq!(classify(&a, inv_sq), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_variable_exponent() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let p = a.push(ExprNode::Pow(x, y));
        assert_eq!(classify(&a, p), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_pow_of_square() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let sq = pow_const(&mut a, x, 2.0);
        let quartic = pow_const(&mut a, sq, 2.0);
        assert_eq!(classify(&a, quartic), ExprClass::Nonlinear);
    }

    #[test]
    fn constant_pow_is_linear() {
        let mut a = ExprArena::new();
        let two = constant(&mut a, 2.0);
        let eight = pow_const(&mut a, two, 3.0);
        assert_eq!(classify(&a, eight), ExprClass::Linear);
    }

    // ---- Structural combinators ----

    #[test]
    fn neg_preserves_degree() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let neg_x = a.push(ExprNode::Neg(x));
        let sq = pow_const(&mut a, x, 2.0);
        let neg_sq = a.push(ExprNode::Neg(sq));
        assert_eq!(classify(&a, neg_x), ExprClass::Linear);
        assert_eq!(classify(&a, neg_sq), ExprClass::Quadratic);
    }

    #[test]
    fn sum_takes_highest_degree() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let y_sq = pow_const(&mut a, y, 2.0);
        let quad_sum = a.push(ExprNode::Add(smallvec![x, y_sq]));
        let sin_y = a.push(ExprNode::Sin(y));
        let nl_sum = a.push(ExprNode::Add(smallvec![x, sin_y]));
        assert_eq!(classify(&a, quad_sum), ExprClass::Quadratic);
        assert_eq!(classify(&a, nl_sum), ExprClass::Nonlinear);
    }

    #[test]
    fn const_times_nonlinear_stays_nonlinear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let c = constant(&mut a, 3.0);
        let s = a.push(ExprNode::Sin(x));
        let m = a.push(ExprNode::Mul(smallvec![c, s]));
        assert_eq!(classify(&a, m), ExprClass::Nonlinear);
    }

    #[test]
    fn empty_sum_and_product_are_linear() {
        let mut a = ExprArena::new();
        let empty_sum = a.push(ExprNode::Add(smallvec![]));
        let empty_prod = a.push(ExprNode::Mul(smallvec![]));
        assert_eq!(classify(&a, empty_sum), ExprClass::Linear);
        assert_eq!(classify(&a, empty_prod), ExprClass::Linear);
    }
}