1use crate::{
2 macros::*,
3 v1::{
4 function::{self, Function as FunctionEnum},
5 Function, Linear, Polynomial, Quadratic, SampledValues, Samples, State,
6 },
7 Bound, Bounds, Evaluate, MonomialDyn, VariableID, VariableIDSet,
8};
9use anyhow::{Context, Result};
10use approx::AbsDiffEq;
11use num::{
12 integer::{gcd, lcm},
13 Zero,
14};
15use std::{collections::HashMap, fmt, iter::*, ops::*};
16
17impl Zero for Function {
18 fn zero() -> Self {
19 Self {
20 function: Some(function::Function::Constant(0.0)),
21 }
22 }
23
24 fn is_zero(&self) -> bool {
25 match &self.function {
26 Some(FunctionEnum::Constant(c)) => c.is_zero(),
27 Some(FunctionEnum::Linear(linear)) => linear.is_zero(),
28 Some(FunctionEnum::Quadratic(quadratic)) => quadratic.is_zero(),
29 Some(FunctionEnum::Polynomial(poly)) => poly.is_zero(),
30 _ => false,
31 }
32 }
33}
34
35impl From<function::Function> for Function {
36 fn from(f: function::Function) -> Self {
37 Self { function: Some(f) }
38 }
39}
40
41impl From<Linear> for Function {
42 fn from(linear: Linear) -> Self {
43 Self {
44 function: Some(function::Function::Linear(linear)),
45 }
46 }
47}
48
49impl From<Quadratic> for Function {
50 fn from(q: Quadratic) -> Self {
51 Self {
52 function: Some(function::Function::Quadratic(q)),
53 }
54 }
55}
56
57impl From<Polynomial> for Function {
58 fn from(poly: Polynomial) -> Self {
59 Self {
60 function: Some(function::Function::Polynomial(poly)),
61 }
62 }
63}
64
65impl From<f64> for Function {
66 fn from(f: f64) -> Self {
67 Self {
68 function: Some(function::Function::Constant(f)),
69 }
70 }
71}
72
73impl FromIterator<(u64, f64)> for Function {
74 fn from_iter<I: IntoIterator<Item = (u64, f64)>>(iter: I) -> Self {
75 let linear: Linear = iter.into_iter().collect();
76 linear.into()
77 }
78}
79
80impl FromIterator<((u64, u64), f64)> for Function {
81 fn from_iter<I: IntoIterator<Item = ((u64, u64), f64)>>(iter: I) -> Self {
82 let quad: Quadratic = iter.into_iter().collect();
83 quad.into()
84 }
85}
86
87impl FromIterator<(MonomialDyn, f64)> for Function {
88 fn from_iter<I: IntoIterator<Item = (MonomialDyn, f64)>>(iter: I) -> Self {
89 let poly: Polynomial = iter.into_iter().collect();
90 poly.into()
91 }
92}
93
94impl<'a> IntoIterator for &'a Function {
95 type Item = (MonomialDyn, f64);
96 type IntoIter = Box<dyn Iterator<Item = Self::Item> + 'a>;
97
98 fn into_iter(self) -> Self::IntoIter {
99 match &self.function {
100 Some(FunctionEnum::Constant(c)) => {
101 Box::new(std::iter::once((MonomialDyn::empty(), *c)))
102 }
103 Some(FunctionEnum::Linear(linear)) => Box::new(
104 linear
105 .into_iter()
106 .map(|(id, c)| (id.map(VariableID::from).into(), c)),
107 ),
108 Some(FunctionEnum::Quadratic(quad)) => Box::new(quad.into_iter()),
109 Some(FunctionEnum::Polynomial(poly)) => Box::new(poly.into_iter()),
110 None => Box::new(std::iter::empty()),
111 }
112 }
113}
114
115impl Function {
116 pub fn degree(&self) -> u32 {
117 match &self.function {
118 Some(FunctionEnum::Constant(_)) => 0,
119 Some(FunctionEnum::Linear(linear)) => linear.degree(),
120 Some(FunctionEnum::Quadratic(quad)) => quad.degree(),
121 Some(FunctionEnum::Polynomial(poly)) => poly.degree(),
122 None => 0,
123 }
124 }
125
126 pub fn as_linear(self) -> Option<Linear> {
127 match self.function? {
128 FunctionEnum::Constant(c) => Some(Linear::from(c)),
129 FunctionEnum::Linear(linear) => Some(linear),
130 FunctionEnum::Quadratic(quadratic) => quadratic.as_linear(),
131 FunctionEnum::Polynomial(poly) => poly.as_linear(),
132 }
133 }
134
135 pub fn as_constant(self) -> Option<f64> {
136 match self.function? {
137 FunctionEnum::Constant(c) => Some(c),
138 FunctionEnum::Linear(linear) => linear.as_constant(),
139 FunctionEnum::Quadratic(quadratic) => quadratic.as_constant(),
140 FunctionEnum::Polynomial(poly) => poly.as_constant(),
141 }
142 }
143
144 pub fn get_constant(&self) -> f64 {
146 match &self.function {
147 Some(FunctionEnum::Constant(c)) => *c,
148 Some(FunctionEnum::Linear(linear)) => linear.constant,
149 Some(FunctionEnum::Quadratic(quad)) => quad.get_constant(),
150 Some(FunctionEnum::Polynomial(poly)) => poly.get_constant(),
151 None => 0.0,
152 }
153 }
154
155 pub fn substitute(&self, replacements: &HashMap<u64, Self>) -> Result<Self> {
160 if replacements.is_empty() {
161 return Ok(self.clone());
162 }
163 let mut out = Function::zero();
164 for (ids, coefficient) in self {
165 let mut v = Function::from(coefficient);
166 for id in ids.iter() {
167 if let Some(replacement) = replacements.get(id) {
168 v = v * replacement.clone();
169 } else {
170 v = v * Linear::single_term(id.into_inner(), 1.0);
171 }
172 }
173 out = out + v;
174 }
175 Ok(out)
176 }
177
178 pub fn evaluate_bound(&self, bounds: &Bounds) -> Bound {
179 let mut bound = Bound::zero();
180 for (ids, value) in self.into_iter() {
181 if value.is_zero() {
182 continue;
183 }
184 if ids.is_empty() {
185 bound += value;
186 continue;
187 }
188 let mut cur = Bound::new(1.0, 1.0).unwrap();
189 for (id, exp) in ids.chunks() {
190 let b = bounds.get(&id).cloned().unwrap_or_default();
191 cur *= b.pow(exp as u8);
192 if cur == Bound::default() {
193 return Bound::default();
194 }
195 }
196 bound += value * cur;
197 }
198 bound
199 }
200
201 pub fn content_factor(&self) -> Result<f64> {
205 let mut numer_gcd = 0;
206 let mut denom_lcm: i64 = 1;
207 for (_, coefficient) in self {
208 let r = num::Rational64::approximate_float(coefficient)
209 .context("Cannot approximate coefficient in 64-bit rational")?;
210 numer_gcd = gcd(numer_gcd, *r.numer());
211 denom_lcm
212 .checked_mul(*r.denom())
213 .context("Overflow detected while evaluating minimal integer coefficient multiplier. This means it is hard to make the all coefficient integer")?;
214 denom_lcm = lcm(denom_lcm, *r.denom());
215 }
216
217 if numer_gcd == 0 {
218 Ok(1.0)
219 } else {
220 Ok((denom_lcm as f64 / numer_gcd as f64).abs())
221 }
222 }
223}
224
225impl Add for Function {
226 type Output = Self;
227
228 fn add(self, rhs: Self) -> Self {
229 let lhs = self.function.expect("Empty Function");
230 let rhs = rhs.function.expect("Empty Function");
231 match (lhs, rhs) {
232 (FunctionEnum::Constant(lhs), FunctionEnum::Constant(rhs)) => Function::from(lhs + rhs),
233 (FunctionEnum::Linear(lhs), FunctionEnum::Constant(rhs))
235 | (FunctionEnum::Constant(rhs), FunctionEnum::Linear(lhs)) => Function::from(lhs + rhs),
236 (FunctionEnum::Linear(lhs), FunctionEnum::Linear(rhs)) => Function::from(lhs + rhs),
237 (FunctionEnum::Quadratic(lhs), FunctionEnum::Constant(rhs))
239 | (FunctionEnum::Constant(rhs), FunctionEnum::Quadratic(lhs)) => {
240 Function::from(lhs + rhs)
241 }
242 (FunctionEnum::Quadratic(lhs), FunctionEnum::Linear(rhs))
243 | (FunctionEnum::Linear(rhs), FunctionEnum::Quadratic(lhs)) => {
244 Function::from(lhs + rhs)
245 }
246 (FunctionEnum::Quadratic(lhs), FunctionEnum::Quadratic(rhs)) => {
247 Function::from(lhs + rhs)
248 }
249 (FunctionEnum::Polynomial(lhs), FunctionEnum::Constant(rhs))
251 | (FunctionEnum::Constant(rhs), FunctionEnum::Polynomial(lhs)) => {
252 Function::from(lhs + rhs)
253 }
254 (FunctionEnum::Polynomial(lhs), FunctionEnum::Linear(rhs))
255 | (FunctionEnum::Linear(rhs), FunctionEnum::Polynomial(lhs)) => {
256 Function::from(lhs + rhs)
257 }
258 (FunctionEnum::Polynomial(lhs), FunctionEnum::Quadratic(rhs))
259 | (FunctionEnum::Quadratic(rhs), FunctionEnum::Polynomial(lhs)) => {
260 Function::from(lhs + rhs)
261 }
262 (FunctionEnum::Polynomial(lhs), FunctionEnum::Polynomial(rhs)) => {
263 Function::from(lhs + rhs)
264 }
265 }
266 }
267}
268
269impl_add_from!(Function, f64);
270impl_add_from!(Function, Linear);
271impl_add_from!(Function, Quadratic);
272impl_add_from!(Function, Polynomial);
273impl_add_inverse!(f64, Function);
274impl_add_inverse!(Linear, Function);
275impl_add_inverse!(Quadratic, Function);
276impl_add_inverse!(Polynomial, Function);
277impl_sub_by_neg_add!(Function, Function);
278impl_sub_by_neg_add!(Function, f64);
279impl_sub_by_neg_add!(Function, Linear);
280impl_sub_by_neg_add!(Function, Quadratic);
281impl_sub_by_neg_add!(Function, Polynomial);
282
283impl Mul for Function {
284 type Output = Self;
285
286 fn mul(self, rhs: Self) -> Self {
287 let lhs = self.function.expect("Empty Function");
288 let rhs = rhs.function.expect("Empty Function");
289 match (lhs, rhs) {
290 (FunctionEnum::Constant(lhs), FunctionEnum::Constant(rhs)) => Function::from(lhs * rhs),
291 (FunctionEnum::Linear(lhs), FunctionEnum::Constant(rhs))
292 | (FunctionEnum::Constant(rhs), FunctionEnum::Linear(lhs)) => Function::from(lhs * rhs),
293 (FunctionEnum::Linear(lhs), FunctionEnum::Linear(rhs)) => Function::from(lhs * rhs),
294 (FunctionEnum::Quadratic(lhs), FunctionEnum::Constant(rhs))
295 | (FunctionEnum::Constant(rhs), FunctionEnum::Quadratic(lhs)) => {
296 Function::from(lhs * rhs)
297 }
298 (FunctionEnum::Quadratic(lhs), FunctionEnum::Linear(rhs))
299 | (FunctionEnum::Linear(rhs), FunctionEnum::Quadratic(lhs)) => {
300 Function::from(lhs * rhs)
301 }
302 (FunctionEnum::Quadratic(lhs), FunctionEnum::Quadratic(rhs)) => {
303 Function::from(lhs * rhs)
304 }
305 (FunctionEnum::Polynomial(lhs), FunctionEnum::Constant(rhs))
306 | (FunctionEnum::Constant(rhs), FunctionEnum::Polynomial(lhs)) => {
307 Function::from(lhs * rhs)
308 }
309 (FunctionEnum::Polynomial(lhs), FunctionEnum::Linear(rhs))
310 | (FunctionEnum::Linear(rhs), FunctionEnum::Polynomial(lhs)) => {
311 Function::from(lhs * rhs)
312 }
313 (FunctionEnum::Polynomial(lhs), FunctionEnum::Quadratic(rhs))
314 | (FunctionEnum::Quadratic(rhs), FunctionEnum::Polynomial(lhs)) => {
315 Function::from(lhs * rhs)
316 }
317 (FunctionEnum::Polynomial(lhs), FunctionEnum::Polynomial(rhs)) => {
318 Function::from(lhs * rhs)
319 }
320 }
321 }
322}
323
324impl_neg_by_mul!(Function);
325impl_mul_from!(Function, f64, Function);
326impl_mul_from!(Function, Linear, Function);
327impl_mul_from!(Function, Quadratic, Function);
328impl_mul_from!(Function, Polynomial, Function);
329impl_mul_inverse!(f64, Function);
330impl_mul_inverse!(Linear, Function);
331impl_mul_inverse!(Quadratic, Function);
332impl_mul_inverse!(Polynomial, Function);
333
334impl Sum for Function {
335 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
336 iter.fold(Function::from(0.0), |acc, x| acc + x)
337 }
338}
339
340impl Product for Function {
341 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
342 iter.fold(Function::from(1.0), |acc, x| acc * x)
343 }
344}
345
346impl AbsDiffEq for Function {
347 type Epsilon = crate::ATol;
348
349 fn default_epsilon() -> Self::Epsilon {
350 crate::ATol::default()
351 }
352
353 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
354 let lhs = self.function.as_ref().expect("Empty Function");
355 let rhs = other.function.as_ref().expect("Empty Function");
356 match (lhs, rhs) {
357 (FunctionEnum::Constant(lhs), FunctionEnum::Constant(rhs)) => {
359 lhs.abs_diff_eq(rhs, *epsilon)
360 }
361 (FunctionEnum::Linear(lhs), FunctionEnum::Linear(rhs)) => lhs.abs_diff_eq(rhs, epsilon),
362 (FunctionEnum::Quadratic(lhs), FunctionEnum::Quadratic(rhs)) => {
363 lhs.abs_diff_eq(rhs, epsilon)
364 }
365 (FunctionEnum::Polynomial(lhs), FunctionEnum::Polynomial(rhs)) => {
366 lhs.abs_diff_eq(rhs, epsilon)
367 }
368 (FunctionEnum::Constant(lhs), FunctionEnum::Linear(rhs))
370 | (FunctionEnum::Linear(rhs), FunctionEnum::Constant(lhs)) => {
371 let lhs = Linear::from(*lhs);
372 lhs.abs_diff_eq(rhs, epsilon)
373 }
374 (FunctionEnum::Constant(lhs), FunctionEnum::Quadratic(rhs))
375 | (FunctionEnum::Quadratic(rhs), FunctionEnum::Constant(lhs)) => {
376 let lhs = Quadratic::from(*lhs);
377 lhs.abs_diff_eq(rhs, epsilon)
378 }
379 (FunctionEnum::Constant(lhs), FunctionEnum::Polynomial(rhs))
380 | (FunctionEnum::Polynomial(rhs), FunctionEnum::Constant(lhs)) => {
381 let lhs = Polynomial::from(*lhs);
382 lhs.abs_diff_eq(rhs, epsilon)
383 }
384 (FunctionEnum::Linear(lhs), FunctionEnum::Quadratic(rhs))
385 | (FunctionEnum::Quadratic(rhs), FunctionEnum::Linear(lhs)) => {
386 let lhs = Quadratic::from(lhs.clone());
387 lhs.abs_diff_eq(rhs, epsilon)
388 }
389 (FunctionEnum::Linear(lhs), FunctionEnum::Polynomial(rhs))
390 | (FunctionEnum::Polynomial(rhs), FunctionEnum::Linear(lhs)) => {
391 let lhs = Polynomial::from(lhs.clone());
392 lhs.abs_diff_eq(rhs, epsilon)
393 }
394 (FunctionEnum::Quadratic(lhs), FunctionEnum::Polynomial(rhs))
395 | (FunctionEnum::Polynomial(rhs), FunctionEnum::Quadratic(lhs)) => {
396 let lhs = Polynomial::from(lhs.clone());
397 lhs.abs_diff_eq(rhs, epsilon)
398 }
399 }
400 }
401}
402
403impl fmt::Display for Function {
404 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
405 match &self.function {
406 Some(FunctionEnum::Constant(c)) => write!(f, "{c}"),
407 Some(FunctionEnum::Linear(linear)) => write!(f, "{linear}"),
408 Some(FunctionEnum::Quadratic(quadratic)) => write!(f, "{quadratic}"),
409 Some(FunctionEnum::Polynomial(poly)) => write!(f, "{poly}"),
410 None => write!(f, "0"),
411 }
412 }
413}
414
415impl Evaluate for Function {
416 type Output = f64;
417 type SampledOutput = SampledValues;
418
419 fn evaluate(&self, solution: &State, atol: crate::ATol) -> Result<f64> {
420 let out = match &self.function {
421 Some(FunctionEnum::Constant(c)) => *c,
422 Some(FunctionEnum::Linear(linear)) => linear.evaluate(solution, atol)?,
423 Some(FunctionEnum::Quadratic(quadratic)) => quadratic.evaluate(solution, atol)?,
424 Some(FunctionEnum::Polynomial(poly)) => poly.evaluate(solution, atol)?,
425 None => 0.0,
426 };
427 Ok(out)
428 }
429
430 fn partial_evaluate(&mut self, state: &State, atol: crate::ATol) -> Result<()> {
431 match &mut self.function {
432 Some(FunctionEnum::Linear(linear)) => linear.partial_evaluate(state, atol)?,
433 Some(FunctionEnum::Quadratic(quadratic)) => quadratic.partial_evaluate(state, atol)?,
434 Some(FunctionEnum::Polynomial(poly)) => poly.partial_evaluate(state, atol)?,
435 _ => {}
436 };
437 Ok(())
438 }
439
440 fn evaluate_samples(
441 &self,
442 samples: &Samples,
443 atol: crate::ATol,
444 ) -> Result<Self::SampledOutput> {
445 let out = samples.map(|s| {
446 let value = self.evaluate(s, atol)?;
447 Ok(value)
448 })?;
449 Ok(out)
450 }
451
452 fn required_ids(&self) -> VariableIDSet {
453 match &self.function {
454 Some(FunctionEnum::Linear(linear)) => linear.required_ids(),
455 Some(FunctionEnum::Quadratic(quadratic)) => quadratic.required_ids(),
456 Some(FunctionEnum::Polynomial(poly)) => poly.required_ids(),
457 _ => VariableIDSet::default(),
458 }
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{random::*, Evaluate, VariableID};
    use maplit::*;

    test_algebraic!(Function);

    /// With no bounds registered, a linear function evaluates to the default bound.
    #[test]
    fn evaluate_bound_missing() {
        let linear = Linear::new([(1, 1.0), (2, 2.0)].into_iter(), 1.0);
        let f = Function::from(linear);
        let no_bounds = Bounds::default();
        assert_eq!(f.evaluate_bound(&no_bounds), Bound::default());
    }

    /// Bounds of a linear and a squared function for known variable bounds.
    #[test]
    fn evaluate_bound() {
        let x1 = Linear::single_term(1, 1.0);
        let x2 = Linear::single_term(2, 2.0);
        let bounds = btreemap! {
            VariableID::from(1) => Bound::new(-1.0, 1.0).unwrap(),
            VariableID::from(2) => Bound::new(2.0, 3.0).unwrap(),
        };

        // x1 + 2 * x2 + 1
        let affine = Function::from(x1.clone() + x2 + 1.0);
        insta::assert_debug_snapshot!(affine.evaluate_bound(&bounds), @"Bound[4, 8]");

        // x1 * x1
        let square = Function::from(x1.clone() * x1);
        insta::assert_debug_snapshot!(square.evaluate_bound(&bounds), @"Bound[0, 1]");
        insta::assert_debug_snapshot!(square.evaluate_bound(&Bounds::default()), @"Bound[0, inf)");
    }

    /// `content_factor` on hand-picked rational and irrational coefficients.
    #[test]
    fn content_factor() {
        use std::f64::consts::PI;

        let x1 = Linear::single_term(1, 1.0);
        let x2 = Linear::single_term(2, 1.0);

        let f = Function::from(x1.clone() + x2.clone());
        assert_eq!(f.content_factor().unwrap(), 1.0);

        let f = Function::from(0.5 * x1.clone() + (1.0 / 3.0) * x2.clone());
        assert_eq!(f.content_factor().unwrap(), 6.0);

        let f = Function::from(2.0 / 3.0 * x1.clone() + 2.0 / 5.0 * x2.clone());
        assert_eq!(f.content_factor().unwrap(), 15.0 / 2.0);

        let f = Function::from(3.0 / 4.0 * x1.clone() + 3.0 / 8.0 * x2.clone());
        assert_eq!(f.content_factor().unwrap(), 8.0 / 3.0);

        let f = Function::from(PI * x1 + 2.0 * PI * x2);
        assert_eq!(f.content_factor().unwrap(), 1.0 / PI);
    }

    proptest! {
        // A function of degree <= 1 survives the round trip through `Linear`.
        #[test]
        fn test_as_linear_roundtrip(f in Function::arbitrary_with(FunctionParameters{ num_terms: 5, max_degree: 1, max_id: 10})) {
            let roundtrip = Function::from(f.clone().as_linear().unwrap());
            prop_assert!(f.abs_diff_eq(&roundtrip, crate::ATol::default()));
        }

        // A function of degree 0 survives the round trip through `f64`.
        #[test]
        fn test_as_constant_roundtrip(f in Function::arbitrary_with(FunctionParameters{ num_terms: 1, max_degree: 0, max_id: 10})) {
            let roundtrip = Function::from(f.clone().as_constant().unwrap());
            prop_assert!(f.abs_diff_eq(&roundtrip, crate::ATol::default()));
        }

        #[test]
        fn test_max_degree_0(f in Function::arbitrary_with(FunctionParameters{ num_terms: 1, max_degree: 0, max_id: 10})) {
            prop_assert!(f.degree() == 0);
        }

        #[test]
        fn test_max_degree_1(f in Function::arbitrary_with(FunctionParameters{ num_terms: 5, max_degree: 1, max_id: 10})) {
            prop_assert!(f.degree() <= 1);
        }

        #[test]
        fn test_max_degree_2(f in Function::arbitrary_with(FunctionParameters{ num_terms: 5, max_degree: 2, max_id: 10})) {
            prop_assert!(f.degree() <= 2);
        }

        // `as_linear` succeeds exactly when the degree is below 2.
        #[test]
        fn test_as_linear_any(f in Function::arbitrary()) {
            prop_assert!((dbg!(f.degree()) >= 2) ^ dbg!(f.as_linear()).is_some());
        }

        // `as_constant` succeeds exactly when the degree is below 1.
        #[test]
        fn test_as_const_any(f in Function::arbitrary()) {
            prop_assert!((dbg!(f.degree()) >= 1) ^ dbg!(f.as_constant()).is_some());
        }

        // The evaluated bound contains the value at any state within the bounds.
        #[test]
        fn evaluate_bound_arb(
            (f, bounds, state) in Function::arbitrary()
                .prop_flat_map(|f| {
                    let bounds = arbitrary_bounds(f.required_ids().into_iter());
                    (Just(f), bounds)
                        .prop_flat_map(|(f, bounds)| {
                            let state = arbitrary_state_within_bounds(&bounds, 1e5);
                            (Just(f), Just(bounds), state)
                        })
                })
        ) {
            let enclosing = f.evaluate_bound(&bounds);
            let value = f.evaluate(&state, crate::ATol::default()).unwrap();
            prop_assert!(enclosing.contains(value, crate::ATol::default()));
        }

        // Scaling by the content factor makes every coefficient (nearly) integral.
        #[test]
        fn content_factor_arb(f in Function::arbitrary()) {
            let Ok(multiplier) = f.content_factor() else { return Ok(()) };
            prop_assert!(multiplier > 0.0);
            let scaled = f * multiplier;
            for (_, c) in &scaled {
                let distance = (c - c.round()).abs();
                // Relative error for large coefficients, absolute error otherwise.
                let error = if c.abs() > 1.0 { distance / c.abs() } else { distance };
                prop_assert!(error < 1e-10, "c = {c}");
            }
        }
    }
}