//! oximo-expr 0.1.0 — integration tests.
//!
//! Arena-allocated expression tree for the oximo optimization framework.
#![allow(clippy::float_cmp)]

use std::cell::RefCell;

use oximo_expr::{Expr, ExprArena, ExprNode, VarId, evaluate, extract_linear};

/// Builds the leaf expression for variable `VarId(idx)` via `Expr::from_var`,
/// tied to the lifetime of `arena`.
fn make_var(arena: &RefCell<ExprArena>, idx: u32) -> Expr<'_> {
    let id = VarId(idx);
    Expr::from_var(arena, id)
}

/// `2x + 3y + 5` must be stored directly as a single `ExprNode::Linear`
/// node carrying both coefficients and the constant.
#[test]
fn linear_fast_path_collapses_add() {
    let arena = RefCell::new(ExprArena::new());
    let x = make_var(&arena, 0);
    let y = make_var(&arena, 1);

    let combo = 2.0 * x + 3.0 * y + 5.0;

    // Clone the node out so the `RefCell` borrow ends with this statement.
    let node = arena.borrow().get(combo.id).clone();
    let (mut coeffs, constant) = match node {
        ExprNode::Linear { coeffs, constant } => (coeffs, constant),
        n => panic!("expected Linear node, got {n:?}"),
    };

    assert_eq!(constant, 5.0);
    assert_eq!(coeffs.len(), 2);
    // Term order inside the node is not asserted; normalise by variable id.
    coeffs.sort_by_key(|(var, _)| var.0);
    assert_eq!(coeffs, vec![(VarId(0), 2.0), (VarId(1), 3.0)]);
}

/// `4x - y - 1`: extracting the linear form must yield a `-1` constant and a
/// `-1` coefficient on `y`.
#[test]
fn linear_fast_path_handles_subtraction() {
    let arena = RefCell::new(ExprArena::new());
    let x = make_var(&arena, 0);
    let y = make_var(&arena, 1);

    let expr = 4.0 * x - y - 1.0;
    let linear = {
        let guard = arena.borrow();
        extract_linear(&guard, expr.id).expect("must be linear")
    };

    assert_eq!(linear.constant, -1.0);
    // Compare coefficients independent of their stored order.
    let mut by_var = linear.coeffs;
    by_var.sort_by_key(|(var, _)| var.0);
    assert_eq!(by_var, vec![(VarId(0), 4.0), (VarId(1), -1.0)]);
}

/// `x^2 + x` contains a `powi` term, so `extract_linear` must return `None`.
#[test]
fn nonlinear_pow_is_not_linear() {
    let arena = RefCell::new(ExprArena::new());
    let x = make_var(&arena, 0);
    let squared = x.powi(2);
    let combo = squared + x;
    let extracted = extract_linear(&arena.borrow(), combo.id);
    assert!(extracted.is_none());
}

/// Evaluating `2x + 3y + 5` with `x = 10`, `y = 7` must match the value
/// computed by hand.
#[test]
fn evaluate_recovers_value() {
    let arena = RefCell::new(ExprArena::new());
    let (x, y) = (make_var(&arena, 0), make_var(&arena, 1));
    let combo = 2.0 * x + 3.0 * y + 5.0;

    // Values are indexed by variable id: x -> 10.0, y -> 7.0.
    let values: &[f64] = &[10.0, 7.0];
    let guard = arena.borrow();
    let actual = evaluate(&guard, combo.id, &values).unwrap();
    assert_eq!(actual, 2.0 * 10.0 + 3.0 * 7.0 + 5.0);
}

/// Unary negation of `2x + 3y + 5` must flip the sign of every coefficient
/// and of the constant term.
#[test]
fn negation_flips_coefficients() {
    let arena = RefCell::new(ExprArena::new());
    let (x, y) = (make_var(&arena, 0), make_var(&arena, 1));
    let positive = 2.0 * x + 3.0 * y + 5.0;
    let combo = -positive;

    let terms = {
        let guard = arena.borrow();
        extract_linear(&guard, combo.id).expect("linear")
    };
    assert_eq!(terms.constant, -5.0);

    let mut by_var = terms.coeffs;
    by_var.sort_by_key(|(var, _)| var.0);
    assert_eq!(by_var, vec![(VarId(0), -2.0), (VarId(1), -3.0)]);
}

/// Summing `N` distinct variables with `oximo_expr::sum` must yield a linear
/// form with a zero constant and exactly one unit coefficient per variable.
#[test]
fn large_sum_extracts_correctly() {
    const N: u32 = 100;

    let arena = RefCell::new(ExprArena::new());
    let vars: Vec<_> = (0..N).map(|i| make_var(&arena, i)).collect();
    let total = oximo_expr::sum(vars.iter().copied());
    let terms = extract_linear(&arena.borrow(), total.id).expect("linear");
    assert_eq!(terms.constant, 0.0);
    assert_eq!(terms.coeffs.len(), N as usize);

    // The length check alone would accept duplicated or missing variables.
    // Sort by id and require the ids to be exactly 0..N, each with
    // coefficient 1.
    let mut coeffs = terms.coeffs;
    coeffs.sort_by_key(|(v, _)| v.0);
    for (expected_id, (var, c)) in (0..N).zip(&coeffs) {
        assert_eq!(var.0, expected_id);
        assert!(
            (*c - 1.0).abs() < f64::EPSILON,
            "coefficient of var {expected_id} is {c}"
        );
    }
}