use crate::arena::{ExprArena, ExprId, ExprNode};
/// Coarse classification of an expression by its polynomial degree in the
/// variables.
///
/// Variants are declared from least to most general, so the derived `Ord`
/// ranks `Linear < Quadratic < Nonlinear` and `max` picks the more general
/// class.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub enum ExprClass {
    /// Degree at most one (constants and affine expressions).
    Linear,
    /// Degree exactly two.
    Quadratic,
    /// Degree three or more, or not a polynomial at all.
    Nonlinear,
}
/// Upper bound on the polynomial degree of an expression.
///
/// Variants are declared from lowest to highest, so the derived `Ord` can be
/// used to pick the larger of two degrees.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
enum Degree {
    /// Constant: independent of every variable.
    Zero,
    /// Affine in the variables.
    One,
    /// Quadratic.
    Two,
    /// Degree three or more, or not polynomial.
    Higher,
}

impl Degree {
    /// Degree of a sum: the larger of the two operand degrees.
    fn add(self, other: Degree) -> Degree {
        if other > self {
            other
        } else {
            self
        }
    }

    /// Degree of a product: operand degrees add, saturating at
    /// [`Degree::Higher`] once the total exceeds two.
    fn mul(self, other: Degree) -> Degree {
        match (self, other) {
            // A constant factor leaves the other operand's degree unchanged
            // (this also covers `Higher * constant`).
            (Self::Zero, d) | (d, Self::Zero) => d,
            (Self::One, Self::One) => Self::Two,
            _ => Self::Higher,
        }
    }

    /// Degree of `self` raised to the integer power `n`.
    fn pow(self, n: u32) -> Degree {
        // Anything to the zeroth power is the constant 1, and a constant to
        // any power stays constant.
        if n == 0 || self == Self::Zero {
            Self::Zero
        } else if n == 1 {
            self
        } else if n == 2 && self == Self::One {
            Self::Two
        } else {
            Self::Higher
        }
    }
}
/// Returns an upper bound on the polynomial degree of the expression at `id`.
///
/// * `Const` / `Param` are degree zero; `Var` / `Linear` are degree one.
/// * `Neg` preserves its operand's degree.
/// * `Add` takes the maximum over its children and `Mul` combines them with
///   [`Degree::mul`]; both stop recursing once [`Degree::Higher`] is reached.
/// * `Pow` can only stay polynomial when the exponent is a literal `Const`.
///   A constant base raised to any constant exponent is itself constant.
///   Otherwise the exponent must be a finite, non-negative integer (within
///   `f64::EPSILON`), and the result is [`Degree::pow`] of the base degree.
/// * `Div`, `Sin`, `Cos`, `Exp`, `Log` and `Abs` are always treated as
///   [`Degree::Higher`], even when their operands are constant.
///
/// There is no memoisation. If the arena lets several parents share a child
/// node, that child is visited once per parent.
fn degree(arena: &ExprArena, id: ExprId) -> Degree {
    match arena.get(id) {
        ExprNode::Const(_) | ExprNode::Param(_) => Degree::Zero,
        ExprNode::Var(_) | ExprNode::Linear { .. } => Degree::One,
        ExprNode::Neg(inner) => degree(arena, *inner),
        ExprNode::Add(children) => {
            let mut d = Degree::Zero;
            for c in children {
                d = d.add(degree(arena, *c));
                // No later term can lower the degree, so stop recursing.
                if d == Degree::Higher {
                    return d;
                }
            }
            d
        }
        ExprNode::Mul(children) => {
            let mut d = Degree::Zero;
            for c in children {
                d = d.mul(degree(arena, *c));
                // `Higher` absorbs every further factor, so stop recursing.
                if d == Degree::Higher {
                    return d;
                }
            }
            d
        }
        ExprNode::Pow(base, exp) => {
            // A non-literal exponent (`x^y`, `x^p`) is treated as nonlinear.
            let ExprNode::Const(e) = arena.get(*exp) else {
                return Degree::Higher;
            };
            let e = *e;
            let base_degree = degree(arena, *base);
            // A constant raised to any constant power is still a constant
            // (e.g. `p^0.5`, `2^-1`), consistent with `Degree::pow(Zero, n) == Zero`.
            if base_degree == Degree::Zero {
                return Degree::Zero;
            }
            // Negative, fractional or non-finite exponents on a non-constant
            // base are not polynomial.
            if !e.is_finite() || e < 0.0 || (e - e.round()).abs() >= f64::EPSILON {
                return Degree::Higher;
            }
            // `Degree::pow` only distinguishes 0, 1, 2 and "3 or more". `e` is a
            // finite non-negative integer here, so after clamping the cast is exact.
            let n = e.round().min(3.0) as u32;
            base_degree.pow(n)
        }
        // NOTE(review): division by a constant (e.g. `x / 2`) is affine in
        // principle. It is kept as `Higher` here, presumably because downstream
        // extraction does not handle `Div`. Confirm before relaxing.
        ExprNode::Div(_, _)
        | ExprNode::Sin(_)
        | ExprNode::Cos(_)
        | ExprNode::Exp(_)
        | ExprNode::Log(_)
        | ExprNode::Abs(_) => Degree::Higher,
    }
}
/// Classifies the expression rooted at `id` as linear, quadratic or nonlinear.
///
/// Constant and affine expressions are both reported as [`ExprClass::Linear`].
/// The underlying degree analysis is conservative. Some expressions it
/// reports as [`ExprClass::Nonlinear`] (for example a division by a constant)
/// may be polynomial after simplification.
pub fn classify(arena: &ExprArena, id: ExprId) -> ExprClass {
    let deg = degree(arena, id);
    match deg {
        Degree::Higher => ExprClass::Nonlinear,
        Degree::Two => ExprClass::Quadratic,
        Degree::One | Degree::Zero => ExprClass::Linear,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{ExprArena, ExprNode, VarId};
    use smallvec::smallvec;

    /// Pushes variable `i` and returns its id.
    fn var(arena: &mut ExprArena, i: u32) -> ExprId {
        arena.push(ExprNode::Var(VarId(i)))
    }

    /// Pushes the literal constant `v` and returns its id.
    fn constant(arena: &mut ExprArena, v: f64) -> ExprId {
        arena.push(ExprNode::Const(v))
    }

    /// Pushes `base ^ exponent`, where `exponent` is a literal constant.
    fn pow_const(arena: &mut ExprArena, base: ExprId, exponent: f64) -> ExprId {
        let e = constant(arena, exponent);
        arena.push(ExprNode::Pow(base, e))
    }

    #[test]
    fn linear_var_sum() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let sum = a.push(ExprNode::Add(smallvec![x, y]));
        assert_eq!(classify(&a, sum), ExprClass::Linear);
    }

    #[test]
    fn constant_alone_is_linear() {
        let mut a = ExprArena::new();
        let c = constant(&mut a, 7.0);
        assert_eq!(classify(&a, c), ExprClass::Linear);
    }

    #[test]
    fn quadratic_mul_two_vars() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        assert_eq!(classify(&a, xy), ExprClass::Quadratic);
    }

    #[test]
    fn quadratic_pow_two() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let sq = pow_const(&mut a, x, 2.0);
        assert_eq!(classify(&a, sq), ExprClass::Quadratic);
    }

    #[test]
    fn nonlinear_pow_three() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let cube = pow_const(&mut a, x, 3.0);
        assert_eq!(classify(&a, cube), ExprClass::Nonlinear);
    }

    #[test]
    fn pow_zero_and_one_are_linear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let x0 = pow_const(&mut a, x, 0.0);
        let x1 = pow_const(&mut a, x, 1.0);
        assert_eq!(classify(&a, x0), ExprClass::Linear);
        assert_eq!(classify(&a, x1), ExprClass::Linear);
    }

    #[test]
    fn fractional_and_negative_exponents_on_var_are_nonlinear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let root = pow_const(&mut a, x, 0.5);
        let inv = pow_const(&mut a, x, -1.0);
        assert_eq!(classify(&a, root), ExprClass::Nonlinear);
        assert_eq!(classify(&a, inv), ExprClass::Nonlinear);
    }

    #[test]
    fn non_literal_exponent_is_nonlinear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let p = a.new_param(2.0);
        let pn = a.param(p);
        let x_to_y = a.push(ExprNode::Pow(x, y));
        let x_to_p = a.push(ExprNode::Pow(x, pn));
        assert_eq!(classify(&a, x_to_y), ExprClass::Nonlinear);
        assert_eq!(classify(&a, x_to_p), ExprClass::Nonlinear);
    }

    #[test]
    fn constant_to_integer_power_is_linear() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        let cube = pow_const(&mut a, pn, 3.0);
        assert_eq!(classify(&a, cube), ExprClass::Linear);
    }

    #[test]
    fn square_of_product_is_nonlinear() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let sq = pow_const(&mut a, xy, 2.0);
        assert_eq!(classify(&a, sq), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_div() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let q = a.push(ExprNode::Div(x, y));
        assert_eq!(classify(&a, q), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_sin() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let s = a.push(ExprNode::Sin(x));
        assert_eq!(classify(&a, s), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_cos_exp_log() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let c = a.push(ExprNode::Cos(x));
        let e = a.push(ExprNode::Exp(x));
        let l = a.push(ExprNode::Log(x));
        assert_eq!(classify(&a, c), ExprClass::Nonlinear);
        assert_eq!(classify(&a, e), ExprClass::Nonlinear);
        assert_eq!(classify(&a, l), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_abs() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let s = a.push(ExprNode::Abs(x));
        assert_eq!(classify(&a, s), ExprClass::Nonlinear);
    }

    #[test]
    fn nonlinear_triple_mul() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let z = var(&mut a, 2);
        let prod = a.push(ExprNode::Mul(smallvec![x, y, z]));
        assert_eq!(classify(&a, prod), ExprClass::Nonlinear);
    }

    #[test]
    fn neg_preserves_degree() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let neg = a.push(ExprNode::Neg(xy));
        assert_eq!(classify(&a, neg), ExprClass::Quadratic);
    }

    #[test]
    fn sum_takes_highest_degree() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let quad_sum = a.push(ExprNode::Add(smallvec![x, xy]));
        assert_eq!(classify(&a, quad_sum), ExprClass::Quadratic);

        let s = a.push(ExprNode::Sin(x));
        let nonlinear_sum = a.push(ExprNode::Add(smallvec![s, x]));
        assert_eq!(classify(&a, nonlinear_sum), ExprClass::Nonlinear);
    }

    #[test]
    fn linear_promoted_by_const_mul() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let c = constant(&mut a, 3.0);
        let m = a.push(ExprNode::Mul(smallvec![c, x]));
        assert_eq!(classify(&a, m), ExprClass::Linear);
    }

    #[test]
    fn const_times_quadratic_stays_quadratic() {
        let mut a = ExprArena::new();
        let x = var(&mut a, 0);
        let y = var(&mut a, 1);
        let c = constant(&mut a, 3.0);
        let xy = a.push(ExprNode::Mul(smallvec![x, y]));
        let m = a.push(ExprNode::Mul(smallvec![c, xy]));
        assert_eq!(classify(&a, m), ExprClass::Quadratic);
    }

    #[test]
    fn param_alone_is_linear() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        assert_eq!(classify(&a, pn), ExprClass::Linear);
    }

    #[test]
    fn param_times_var_is_linear() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        let x = var(&mut a, 0);
        let m = a.push(ExprNode::Mul(smallvec![pn, x]));
        assert_eq!(classify(&a, m), ExprClass::Linear);
    }

    #[test]
    fn param_times_var_squared_is_quadratic() {
        let mut a = ExprArena::new();
        let p = a.new_param(4.0);
        let pn = a.param(p);
        let x = var(&mut a, 0);
        let sq = pow_const(&mut a, x, 2.0);
        let m = a.push(ExprNode::Mul(smallvec![pn, sq]));
        assert_eq!(classify(&a, m), ExprClass::Quadratic);
    }
}