use rustc_hash::{FxBuildHasher, FxHashMap};
use smallvec::smallvec;
use crate::arena::{ExprArena, ExprId, ExprNode, VarId};
/// An affine expression in flattened form: `constant + Σ coeff · var`.
///
/// The `Default` value has no coefficients and a constant of `0.0`.
#[derive(Clone, Debug, Default)]
pub struct LinearTerms {
/// `(variable, coefficient)` pairs. Producers in this file do not promise any
/// particular ordering, and some paths may leave zero-valued coefficients in place.
pub coeffs: Vec<(VarId, f64)>,
/// The constant (variable-free) part of the expression.
pub constant: f64,
}
/// Tries to flatten the expression rooted at `id` into [`LinearTerms`].
///
/// Returns `Some` when the subtree consists only of constants, variables,
/// `Linear` nodes, negations, sums of linear children, and products with at
/// most one factor that carries variable terms. Returns `None` otherwise.
///
/// When `resolve_params` is `true`, `Param` nodes are replaced by
/// `arena.param_value(..)`. When it is `false`, a `Param` anywhere in the
/// subtree makes the result `None`, so the expression stays symbolic.
///
/// Notes on the output:
/// * `Add` merges coefficients per variable through a hash map, so the order of
///   `coeffs` is unspecified and zero coefficients are not pruned here.
/// * `Linear` nodes are returned as stored (cloned).
fn as_linear(arena: &ExprArena, id: ExprId, resolve_params: bool) -> Option<LinearTerms> {
    match arena.get(id) {
        ExprNode::Const(c) => Some(LinearTerms { coeffs: Vec::new(), constant: *c }),
        ExprNode::Param(p) if resolve_params => {
            Some(LinearTerms { coeffs: Vec::new(), constant: arena.param_value(*p) })
        }
        ExprNode::Var(v) => Some(LinearTerms { coeffs: vec![(*v, 1.0)], constant: 0.0 }),
        ExprNode::Linear { coeffs, constant } => {
            Some(LinearTerms { coeffs: coeffs.clone(), constant: *constant })
        }
        ExprNode::Neg(inner) => {
            let mut t = as_linear(arena, *inner, resolve_params)?;
            t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
            t.constant = -t.constant;
            Some(t)
        }
        ExprNode::Add(children) => {
            let mut map: FxHashMap<VarId, f64> =
                FxHashMap::with_capacity_and_hasher(children.len() * 4, FxBuildHasher);
            let mut constant: f64 = 0.0;
            // The arena is only borrowed immutably, so children can be walked in place.
            for &child in children.iter() {
                let t = as_linear(arena, child, resolve_params)?;
                for (v, c) in t.coeffs {
                    *map.entry(v).or_insert(0.0) += c;
                }
                constant += t.constant;
            }
            Some(LinearTerms { coeffs: map.into_iter().collect(), constant })
        }
        ExprNode::Mul(children) => {
            // Product of all constant factors seen so far.
            let mut scalar: f64 = 1.0;
            // The single factor that is allowed to carry variable terms.
            let mut linear: Option<LinearTerms> = None;
            for &child in children.iter() {
                match arena.get(child) {
                    ExprNode::Const(c) => scalar *= *c,
                    ExprNode::Param(p) if resolve_params => scalar *= arena.param_value(*p),
                    _ => {
                        let t = as_linear(arena, child, resolve_params)?;
                        if t.coeffs.is_empty() {
                            // A factor that flattens to a bare constant (e.g. a
                            // coefficient-free `Linear` node, or `Neg(Param)` with
                            // parameters resolved) is just another scalar. It must not
                            // occupy the linear slot, otherwise `(-2) * x` would be
                            // rejected as non-linear.
                            scalar *= t.constant;
                        } else if linear.is_none() {
                            linear = Some(t);
                        } else {
                            // Two factors with variable terms: not linear.
                            return None;
                        }
                    }
                }
            }
            Some(match linear {
                None => LinearTerms { coeffs: Vec::new(), constant: scalar },
                Some(mut t) => {
                    t.coeffs.iter_mut().for_each(|(_, c)| *c *= scalar);
                    t.constant *= scalar;
                    t
                }
            })
        }
        _ => None,
    }
}
fn push_linear(arena: &mut ExprArena, mut t: LinearTerms) -> ExprId {
t.coeffs.retain(|(_, c)| *c != 0.0);
arena.push(ExprNode::Linear { coeffs: t.coeffs, constant: t.constant })
}
/// Builds `lhs + rhs` in the arena and returns the id of the resulting node.
///
/// If both operands flatten to linear form without resolving parameters, their
/// coefficients are merged per variable and a single `Linear` node is pushed.
/// Otherwise a two-child `Add` node is pushed.
pub(crate) fn add_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
    let left = as_linear(arena, lhs, false);
    let right = as_linear(arena, rhs, false);
    match (left, right) {
        (Some(left), Some(right)) => {
            let constant = left.constant + right.constant;
            let mut merged: FxHashMap<VarId, f64> = FxHashMap::with_capacity_and_hasher(
                left.coeffs.len() + right.coeffs.len(),
                FxBuildHasher,
            );
            for (var, coeff) in left.coeffs.into_iter().chain(right.coeffs) {
                *merged.entry(var).or_insert(0.0) += coeff;
            }
            let coeffs = merged.into_iter().collect();
            push_linear(arena, LinearTerms { coeffs, constant })
        }
        // At least one side is not linear (or contains an unresolved parameter).
        _ => arena.push(ExprNode::Add(smallvec![lhs, rhs])),
    }
}
/// Builds `lhs - rhs` as `lhs + (-rhs)`: negates `rhs` with `neg_into`, then
/// adds the result to `lhs` with `add_into`.
pub(crate) fn sub_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
    let negated_rhs = neg_into(arena, rhs);
    add_into(arena, lhs, negated_rhs)
}
pub(crate) fn mul_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
if let ExprNode::Const(c) = *arena.get(lhs) {
if let Some(mut t) = as_linear(arena, rhs, false) {
t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
t.constant *= c;
return push_linear(arena, t);
}
}
if let ExprNode::Const(c) = *arena.get(rhs) {
if let Some(mut t) = as_linear(arena, lhs, false) {
t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
t.constant *= c;
return push_linear(arena, t);
}
}
arena.push(ExprNode::Mul(smallvec![lhs, rhs]))
}
/// Builds `num / den` in the arena and returns the id of the resulting node.
///
/// * `den` is a `Const` other than `0.0` and `num` flattens to linear form
///   (parameters unresolved): every term is divided by the constant and one
///   `Linear` node is pushed.
/// * `den` is a `Const` other than `0.0` but `num` is not linear: a
///   `Const(1.0 / c)` node is pushed and the result is `mul_into(num, that)`.
/// * Otherwise (non-constant or zero denominator): a `Div` node is pushed.
pub(crate) fn div_into(arena: &mut ExprArena, num: ExprId, den: ExprId) -> ExprId {
    if let ExprNode::Const(c) = *arena.get(den) {
        if c != 0.0 {
            if let Some(mut t) = as_linear(arena, num, false) {
                // Divide directly instead of multiplying by `1.0 / c`. The
                // reciprocal rounds once and the product rounds again, so e.g.
                // `49.0 * (1.0 / 49.0)` is `0.9999999999999999`, not `1.0`.
                // For subnormal `c` the reciprocal also overflows to infinity.
                t.coeffs.iter_mut().for_each(|(_, co)| *co /= c);
                t.constant /= c;
                return push_linear(arena, t);
            }
            let inv = arena.push(ExprNode::Const(1.0 / c));
            return mul_into(arena, num, inv);
        }
    }
    arena.push(ExprNode::Div(num, den))
}
pub(crate) fn neg_into(arena: &mut ExprArena, rhs: ExprId) -> ExprId {
if let Some(mut t) = as_linear(arena, rhs, false) {
t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
t.constant = -t.constant;
return push_linear(arena, t);
}
arena.push(ExprNode::Neg(rhs))
}
/// Flattens the expression rooted at `id` into [`LinearTerms`], substituting
/// the current value of every parameter. Returns `None` if the expression is
/// not linear.
pub fn extract_linear(arena: &ExprArena, id: ExprId) -> Option<LinearTerms> {
    const RESOLVE_PARAMS: bool = true;
    as_linear(arena, id, RESOLVE_PARAMS)
}
/// A reference to an arena expression together with the sign it carries in an
/// enclosing sum (used by `split_linear` for its non-linear leftovers).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SignedExpr {
/// The expression node being referred to.
pub id: ExprId,
/// `true` when the expression enters the sum negated.
pub neg: bool,
}
pub fn split_linear(arena: &ExprArena, id: ExprId) -> (LinearTerms, Vec<SignedExpr>) {
if let Some(lt) = as_linear(arena, id, true) {
return (lt, Vec::new());
}
let mut lin = LinearTerms::default();
let mut residual: Vec<SignedExpr> = Vec::new();
let mut sign_stack: smallvec::SmallVec<[(ExprId, f64); 8]> = smallvec![(id, 1.0)];
while let Some((cur, sign)) = sign_stack.pop() {
match arena.get(cur) {
ExprNode::Add(children) => {
for c in children.iter().copied() {
sign_stack.push((c, sign));
}
}
ExprNode::Neg(inner) => sign_stack.push((*inner, -sign)),
_ => {
if let Some(mut t) = as_linear(arena, cur, true) {
if (sign - 1.0).abs() > 0.0 {
t.coeffs.iter_mut().for_each(|(_, c)| *c *= sign);
t.constant *= sign;
}
for (v, c) in t.coeffs {
if let Some((_, acc)) = lin.coeffs.iter_mut().find(|(vv, _)| *vv == v) {
*acc += c;
} else {
lin.coeffs.push((v, c));
}
}
lin.constant += t.constant;
} else {
residual.push(SignedExpr { id: cur, neg: sign < 0.0 });
}
}
}
}
lin.coeffs.retain(|(_, c)| *c != 0.0);
(lin, residual)
}
// Unit tests covering how parameters interact with linear folding: builders
// keep parameters symbolic, while `extract_linear` substitutes their values.
#[cfg(test)]
mod tests {
use super::*;
use crate::arena::{ExprArena, ExprNode, VarId};
// `param * var` must not be folded at build time (it stays a `Mul` node), but
// extraction resolves the parameter into the variable's coefficient.
#[test]
fn param_times_var_stays_symbolic_until_extracted() {
let mut arena = ExprArena::new();
let pid = arena.new_param(3.0);
let price = arena.param(pid);
let xnode = arena.push(ExprNode::Var(VarId(0)));
let prod = mul_into(&mut arena, price, xnode);
assert!(matches!(arena.get(prod), ExprNode::Mul(_)));
let terms = extract_linear(&arena, prod).expect("linear");
assert_eq!(terms.coeffs, vec![(VarId(0), 3.0)]);
assert!(terms.constant.abs() < f64::EPSILON);
}
// Changing a parameter's value after the expression was built is reflected in
// the next extraction.
#[test]
fn rebinding_param_updates_extracted_coeff() {
let mut arena = ExprArena::new();
let pid = arena.new_param(3.0);
let price = arena.param(pid);
let xnode = arena.push(ExprNode::Var(VarId(0)));
let prod = mul_into(&mut arena, price, xnode);
arena.set_param_value(pid, 10.0);
let terms = extract_linear(&arena, prod).expect("linear");
assert_eq!(terms.coeffs, vec![(VarId(0), 10.0)]);
}
// `param + var` extracts to coefficient 1 on the variable and the parameter's
// value as the constant term.
#[test]
fn param_plus_var_resolves_constant() {
let mut arena = ExprArena::new();
let pid = arena.new_param(5.0);
let price = arena.param(pid);
let xnode = arena.push(ExprNode::Var(VarId(0)));
let sum = add_into(&mut arena, price, xnode);
let terms = extract_linear(&arena, sum).expect("linear");
assert_eq!(terms.coeffs, vec![(VarId(0), 1.0)]);
assert!((terms.constant - 5.0).abs() < f64::EPSILON);
}
}