1use rustc_hash::{FxBuildHasher, FxHashMap};
2use smallvec::smallvec;
3
4use crate::arena::{ExprArena, ExprId, ExprNode, VarId};
5
/// An affine form `sum(coeff * var) + constant` over solver variables.
///
/// Values produced by this module list each variable at most once (assuming every
/// `Linear` node in the arena was built by this module — TODO confirm), in no
/// particular order. Depending on the producer, variables whose terms cancelled out
/// may still appear with an explicit `0.0` coefficient.
#[derive(Clone, Debug, Default)]
pub struct LinearTerms {
    /// `(variable, coefficient)` pairs.
    pub coeffs: Vec<(VarId, f64)>,
    /// Constant offset added to the weighted sum of variables.
    pub constant: f64,
}
12
13fn as_linear(arena: &ExprArena, id: ExprId, resolve_params: bool) -> Option<LinearTerms> {
19 match arena.get(id) {
20 ExprNode::Const(c) => Some(LinearTerms { coeffs: Vec::new(), constant: *c }),
21 ExprNode::Param(p) if resolve_params => {
22 Some(LinearTerms { coeffs: Vec::new(), constant: arena.param_value(*p) })
23 }
24 ExprNode::Var(v) => Some(LinearTerms { coeffs: vec![(*v, 1.0)], constant: 0.0 }),
25 ExprNode::Linear { coeffs, constant } => {
26 Some(LinearTerms { coeffs: coeffs.clone(), constant: *constant })
27 }
28 ExprNode::Neg(inner) => {
29 let inner = *inner;
30 as_linear(arena, inner, resolve_params).map(|mut t| {
31 t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
32 t.constant = -t.constant;
33 t
34 })
35 }
36 ExprNode::Add(children) => {
37 let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
38 let mut acc = LinearTerms::default();
39 let mut map: FxHashMap<VarId, f64> =
40 FxHashMap::with_capacity_and_hasher(children.len() * 4, FxBuildHasher);
41 for child in children {
42 let t = as_linear(arena, child, resolve_params)?;
43 for (v, c) in t.coeffs {
44 *map.entry(v).or_insert(0.0) += c;
45 }
46 acc.constant += t.constant;
47 }
48 acc.coeffs = map.into_iter().collect();
49 Some(acc)
50 }
51 ExprNode::Mul(children) => {
52 let children: smallvec::SmallVec<[ExprId; 4]> = children.iter().copied().collect();
54 let mut scalar = 1.0;
55 let mut linear: Option<LinearTerms> = None;
56 for child in children {
57 match arena.get(child) {
58 ExprNode::Const(c) => scalar *= c,
59 ExprNode::Param(p) if resolve_params => scalar *= arena.param_value(*p),
60 _ if linear.is_none() => {
61 linear = Some(as_linear(arena, child, resolve_params)?);
62 }
63 _ => return None,
64 }
65 }
66 Some(match linear {
67 None => LinearTerms { coeffs: Vec::new(), constant: scalar },
68 Some(mut t) => {
69 t.coeffs.iter_mut().for_each(|(_, c)| *c *= scalar);
70 t.constant *= scalar;
71 t
72 }
73 })
74 }
75 _ => None,
76 }
77}
78
79fn push_linear(arena: &mut ExprArena, mut t: LinearTerms) -> ExprId {
81 t.coeffs.retain(|(_, c)| *c != 0.0);
82 arena.push(ExprNode::Linear { coeffs: t.coeffs, constant: t.constant })
83}
84
85pub(crate) fn add_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
88 if let (Some(lt), Some(rt)) = (as_linear(arena, lhs, false), as_linear(arena, rhs, false)) {
89 let mut map: FxHashMap<VarId, f64> =
90 FxHashMap::with_capacity_and_hasher(lt.coeffs.len() + rt.coeffs.len(), FxBuildHasher);
91 for (v, c) in lt.coeffs.into_iter().chain(rt.coeffs) {
92 *map.entry(v).or_insert(0.0) += c;
93 }
94 return push_linear(
95 arena,
96 LinearTerms { coeffs: map.into_iter().collect(), constant: lt.constant + rt.constant },
97 );
98 }
99 arena.push(ExprNode::Add(smallvec![lhs, rhs]))
100}
101
102pub(crate) fn add_n(arena: &mut ExprArena, ids: &[ExprId]) -> ExprId {
109 match ids {
110 [] => panic!("add_n on an empty term list"),
111 [one] => *one,
112 _ => arena.push(ExprNode::Add(ids.iter().copied().collect())),
113 }
114}
115
116pub(crate) fn sub_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
118 let neg = neg_into(arena, rhs);
119 add_into(arena, lhs, neg)
120}
121
122pub(crate) fn mul_into(arena: &mut ExprArena, lhs: ExprId, rhs: ExprId) -> ExprId {
125 if let ExprNode::Const(c) = *arena.get(lhs) {
126 if let Some(mut t) = as_linear(arena, rhs, false) {
127 t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
128 t.constant *= c;
129 return push_linear(arena, t);
130 }
131 }
132 if let ExprNode::Const(c) = *arena.get(rhs) {
133 if let Some(mut t) = as_linear(arena, lhs, false) {
134 t.coeffs.iter_mut().for_each(|(_, co)| *co *= c);
135 t.constant *= c;
136 return push_linear(arena, t);
137 }
138 }
139 arena.push(ExprNode::Mul(smallvec![lhs, rhs]))
140}
141
142pub(crate) fn div_into(arena: &mut ExprArena, num: ExprId, den: ExprId) -> ExprId {
146 if let ExprNode::Const(c) = *arena.get(den) {
147 if c != 0.0 {
148 if let Some(mut t) = as_linear(arena, num, false) {
149 let inv = 1.0 / c;
150 t.coeffs.iter_mut().for_each(|(_, co)| *co *= inv);
151 t.constant *= inv;
152 return push_linear(arena, t);
153 }
154 let inv = arena.push(ExprNode::Const(1.0 / c));
155 return mul_into(arena, num, inv);
156 }
157 }
158 arena.push(ExprNode::Div(num, den))
159}
160
161pub(crate) fn neg_into(arena: &mut ExprArena, rhs: ExprId) -> ExprId {
163 if let Some(mut t) = as_linear(arena, rhs, false) {
164 t.coeffs.iter_mut().for_each(|(_, c)| *c = -*c);
165 t.constant = -t.constant;
166 return push_linear(arena, t);
167 }
168 arena.push(ExprNode::Neg(rhs))
169}
170
171pub fn extract_linear(arena: &ExprArena, id: ExprId) -> Option<LinearTerms> {
179 as_linear(arena, id, true)
180}
181
/// A sub-expression together with the sign it carries inside an enclosing sum.
///
/// This is the form in which `split_linear` reports its non-affine residual terms.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SignedExpr {
    /// The sub-expression.
    pub id: ExprId,
    /// `true` when the term enters the sum negated, i.e. it contributes `-expr(id)`.
    /// In `split_linear` this means an odd number of `Neg` nodes were stripped on the
    /// way down to it.
    pub neg: bool,
}
190
191pub fn split_linear(arena: &ExprArena, id: ExprId) -> (LinearTerms, Vec<SignedExpr>) {
203 if let Some(lt) = as_linear(arena, id, true) {
204 return (lt, Vec::new());
205 }
206 let mut lin = LinearTerms::default();
207 let mut residual: Vec<SignedExpr> = Vec::new();
208 let mut sign_stack: smallvec::SmallVec<[(ExprId, f64); 8]> = smallvec![(id, 1.0)];
209 while let Some((cur, sign)) = sign_stack.pop() {
210 match arena.get(cur) {
211 ExprNode::Add(children) => {
212 for c in children.iter().copied() {
213 sign_stack.push((c, sign));
214 }
215 }
216 ExprNode::Neg(inner) => sign_stack.push((*inner, -sign)),
217 _ => {
218 if let Some(mut t) = as_linear(arena, cur, true) {
219 if (sign - 1.0).abs() > 0.0 {
220 t.coeffs.iter_mut().for_each(|(_, c)| *c *= sign);
221 t.constant *= sign;
222 }
223 for (v, c) in t.coeffs {
224 if let Some((_, acc)) = lin.coeffs.iter_mut().find(|(vv, _)| *vv == v) {
225 *acc += c;
226 } else {
227 lin.coeffs.push((v, c));
228 }
229 }
230 lin.constant += t.constant;
231 } else {
232 residual.push(SignedExpr { id: cur, neg: sign < 0.0 });
233 }
234 }
235 }
236 }
237 lin.coeffs.retain(|(_, c)| *c != 0.0);
238 (lin, residual)
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arena::{ExprArena, ExprNode, VarId};

    /// `param * var` must stay a symbolic `Mul` node while the expression is built,
    /// and resolve to a numeric coefficient only on extraction.
    #[test]
    fn param_times_var_stays_symbolic_until_extracted() {
        let mut arena = ExprArena::new();
        let price_id = arena.new_param(3.0);
        let price = arena.param(price_id);
        let x = arena.push(ExprNode::Var(VarId(0)));

        let product = mul_into(&mut arena, price, x);
        assert!(matches!(arena.get(product), ExprNode::Mul(_)));

        let extracted = extract_linear(&arena, product).expect("linear");
        assert_eq!(extracted.coeffs, vec![(VarId(0), 3.0)]);
        assert!(extracted.constant.abs() < f64::EPSILON);
    }

    /// Changing a parameter after building must be visible at extraction time.
    #[test]
    fn rebinding_param_updates_extracted_coeff() {
        let mut arena = ExprArena::new();
        let price_id = arena.new_param(3.0);
        let price = arena.param(price_id);
        let x = arena.push(ExprNode::Var(VarId(0)));
        let product = mul_into(&mut arena, price, x);

        arena.set_param_value(price_id, 10.0);
        let extracted = extract_linear(&arena, product).expect("linear");
        assert_eq!(extracted.coeffs, vec![(VarId(0), 10.0)]);
    }

    /// An additive parameter ends up in the extracted constant.
    #[test]
    fn param_plus_var_resolves_constant() {
        let mut arena = ExprArena::new();
        let price_id = arena.new_param(5.0);
        let price = arena.param(price_id);
        let x = arena.push(ExprNode::Var(VarId(0)));

        let total = add_into(&mut arena, price, x);
        let extracted = extract_linear(&arena, total).expect("linear");
        assert_eq!(extracted.coeffs, vec![(VarId(0), 1.0)]);
        assert!((extracted.constant - 5.0).abs() < f64::EPSILON);
    }
}