1use rustc_hash::{FxBuildHasher, FxHashMap};
2use smallvec::smallvec;
3
4use crate::arena::{ExprArena, ExprId, ExprNode, VarId};
5
6#[derive(Clone, Debug, Default)]
8pub struct LinearTerms {
9 pub coeffs: Vec<(VarId, f64)>,
10 pub constant: f64,
11}
12
13fn as_linear(arena: &ExprArena, id: ExprId) -> Option<LinearTerms> {
16 match arena.get(id) {
17 ExprNode::Const(c) => Some(LinearTerms { coeffs: Vec::new(), constant: *c }),
18 ExprNode::Var(v) => Some(LinearTerms { coeffs: vec![(*v, 1.0)], constant: 0.0 }),
19 ExprNode::Linear { coeffs, constant } => {
20 Some(LinearTerms { coeffs: coeffs.clone(), constant: *constant })
21 }
22 ExprNode::Neg(inner) => {
23 let inner = *inner;
24 as_linear(arena, inner).map(|mut t| {
25 t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
26 t.constant = -t.constant;
27 t
28 })
29 }
30 ExprNode::Add(children) => {
31 let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
32 let mut acc = LinearTerms::default();
33 let mut map: FxHashMap<VarId, f64> =
34 FxHashMap::with_capacity_and_hasher(children.len() * 4, FxBuildHasher);
35 for child in children {
36 let t = as_linear(arena, child)?;
37 for (v, c) in t.coeffs {
38 *map.entry(v).or_insert(0.0) += c;
39 }
40 acc.constant += t.constant;
41 }
42 acc.coeffs = map.into_iter().collect();
43 Some(acc)
44 }
45 ExprNode::Mul(children) => {
46 let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
48 let mut scalar = 1.0;
49 let mut linear: Option<LinearTerms> = None;
50 for child in children {
51 if let ExprNode::Const(c) = arena.get(child) {
52 scalar *= c;
53 } else if linear.is_none() {
54 linear = Some(as_linear(arena, child)?);
55 } else {
56 return None;
57 }
58 }
59 Some(match linear {
60 None => LinearTerms { coeffs: Vec::new(), constant: scalar },
61 Some(mut t) => {
62 t.coeffs.iter_mut().for_each(|(_, c)| *c *= scalar);
63 t.constant *= scalar;
64 t
65 }
66 })
67 }
68 _ => None,
69 }
70}
71
72fn push_linear(arena: &mut ExprArena, mut t: LinearTerms) -> ExprId {
74 t.coeffs.retain(|(_, c)| *c != 0.0);
75 arena.push(ExprNode::Linear { coeffs: t.coeffs, constant: t.constant })
76}
77
78pub(crate) fn add_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
81 if let (Some(lt), Some(rt)) = (as_linear(arena, lhs), as_linear(arena, rhs)) {
82 let mut map: FxHashMap<VarId, f64> =
83 FxHashMap::with_capacity_and_hasher(lt.coeffs.len() + rt.coeffs.len(), FxBuildHasher);
84 for (v, c) in lt.coeffs.into_iter().chain(rt.coeffs) {
85 *map.entry(v).or_insert(0.0) += c;
86 }
87 return push_linear(
88 arena,
89 LinearTerms { coeffs: map.into_iter().collect(), constant: lt.constant + rt.constant },
90 );
91 }
92 arena.push(ExprNode::Add(smallvec![lhs, rhs]))
93}
94
95pub(crate) fn sub_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
97 let neg = neg_into(arena, rhs);
98 add_into(arena, lhs, neg)
99}
100
101pub(crate) fn mul_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
104 if let ExprNode::Const(c) = *arena.get(lhs) {
105 if let Some(mut t) = as_linear(arena, rhs) {
106 t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
107 t.constant *= c;
108 return push_linear(arena, t);
109 }
110 }
111 if let ExprNode::Const(c) = *arena.get(rhs) {
112 if let Some(mut t) = as_linear(arena, lhs) {
113 t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
114 t.constant *= c;
115 return push_linear(arena, t);
116 }
117 }
118 arena.push(ExprNode::Mul(smallvec![lhs, rhs]))
119}
120
121pub(crate) fn neg_into(arena: &mut ExprArena, rhs: ExprId) -> ExprId {
123 if let Some(mut t) = as_linear(arena, rhs) {
124 t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
125 t.constant = -t.constant;
126 return push_linear(arena, t);
127 }
128 arena.push(ExprNode::Neg(rhs))
129}
130
131pub fn extract_linear(arena: &ExprArena, id: ExprId) -> Option<LinearTerms> {
134 as_linear(arena, id)
135}