use crate::{
macros::*,
v1::{linear::Term, Linear, Quadratic, SampledValues, Samples, State},
Evaluate, VariableID, VariableIDSet,
};
use anyhow::{Context, Result};
use approx::AbsDiffEq;
use num::Zero;
use std::{collections::BTreeMap, fmt, iter::Sum, ops::*};
impl Zero for Linear {
fn zero() -> Self {
Self::from(0.0)
}
fn is_zero(&self) -> bool {
self.terms.is_empty() && self.constant == 0.0
}
}
impl Linear {
    /// Builds a linear function from `(variable id, coefficient)` pairs and a constant.
    ///
    /// Coefficients that share an id are summed first. Only after all pairs
    /// have been merged is a term dropped, namely when its merged coefficient
    /// satisfies `|c| <= f64::EPSILON`. Pruning after merging keeps the result
    /// independent of the order in which the pairs arrive and lets many small
    /// contributions add up to a significant coefficient.
    ///
    /// The resulting terms are stored in ascending id order.
    pub fn new(terms: impl Iterator<Item = (u64, f64)>, constant: f64) -> Self {
        let mut merged: BTreeMap<u64, f64> = BTreeMap::new();
        for (id, coefficient) in terms {
            *merged.entry(id).or_default() += coefficient;
        }
        let terms = merged
            .into_iter()
            // NaN coefficients are kept on purpose so invalid input stays visible.
            .filter(|(_, coefficient)| coefficient.is_nan() || coefficient.abs() > f64::EPSILON)
            .map(|(id, coefficient)| Term { id, coefficient })
            .collect();
        Self { terms, constant }
    }

    /// Builds `coefficient * x_id` with a zero constant.
    ///
    /// The coefficient is stored as given; no near-zero pruning is applied.
    pub fn single_term(id: u64, coefficient: f64) -> Self {
        Self {
            terms: vec![Term { id, coefficient }],
            constant: 0.0,
        }
    }

    /// Returns `0` when no terms are stored and `1` otherwise.
    pub fn degree(&self) -> u32 {
        u32::from(!self.terms.is_empty())
    }

    /// Returns the constant when no terms are stored, otherwise `None`.
    pub fn as_constant(self) -> Option<f64> {
        self.terms.is_empty().then_some(self.constant)
    }
}
impl From<u64> for Linear {
fn from(id: u64) -> Self {
Self::single_term(id, 1.0)
}
}
/// Interprets a bare `f64` as a constant function.
impl From<f64> for Linear {
    /// Returns a function with no terms whose constant is `constant`.
    fn from(constant: f64) -> Self {
        let terms = Vec::new();
        Self { terms, constant }
    }
}
impl FromIterator<(u64, f64)> for Linear {
fn from_iter<I: IntoIterator<Item = (u64, f64)>>(iter: I) -> Self {
Self::new(iter.into_iter(), 0.0)
}
}
/// Collects `(Option<variable id>, coefficient)` pairs into a [`Linear`].
///
/// `Some(id)` contributes to the term for `id`; `None` contributes to the constant.
impl FromIterator<(Option<u64>, f64)> for Linear {
    /// Sums coefficients per key and then distributes them into terms and constant.
    ///
    /// Unlike [`Linear::new`], merged coefficients are not pruned here, so a
    /// term whose coefficients cancel is still stored with a zero coefficient.
    fn from_iter<I: IntoIterator<Item = (Option<u64>, f64)>>(iter: I) -> Self {
        let mut merged: BTreeMap<Option<u64>, f64> = BTreeMap::new();
        iter.into_iter().for_each(|(key, value)| {
            *merged.entry(key).or_default() += value;
        });

        let mut result = Linear::default();
        for (key, coefficient) in merged {
            match key {
                Some(id) => result.terms.push(Term { id, coefficient }),
                None => result.constant += coefficient,
            }
        }
        result
    }
}
/// Iterates over the non-zero entries of a [`Linear`] by reference.
impl<'a> IntoIterator for &'a Linear {
    /// `(Some(id), coefficient)` for a variable term, `(None, constant)` for the constant.
    type Item = (Option<u64>, f64);
    type IntoIter = Box<dyn Iterator<Item = Self::Item> + 'a>;

    /// Yields the stored terms in order, followed by the constant, skipping
    /// every entry whose value is zero.
    fn into_iter(self) -> Self::IntoIter {
        let variable_entries = self
            .terms
            .iter()
            .map(|term| (Some(term.id), term.coefficient));
        let constant_entry = std::iter::once((None, self.constant));
        let nonzero = variable_entries
            .chain(constant_entry)
            .filter(|entry| !entry.1.is_zero());
        Box::new(nonzero)
    }
}
/// Term-wise addition of two linear functions.
impl Add for Linear {
    type Output = Self;

    /// Adds coefficients that share a variable id and adds the constants.
    ///
    /// All coefficients are merged first; a term is dropped only when its
    /// merged coefficient satisfies `|c| <= f64::EPSILON`. Pruning after the
    /// merge keeps the operation commutative: a tiny coefficient on one side
    /// is no longer discarded before the other side's coefficient is added.
    ///
    /// The resulting terms are stored in ascending id order.
    fn add(self, rhs: Self) -> Self {
        let mut merged: BTreeMap<u64, f64> = BTreeMap::new();
        for term in self.terms.iter().chain(rhs.terms.iter()) {
            *merged.entry(term.id).or_default() += term.coefficient;
        }
        let terms = merged
            .into_iter()
            // NaN coefficients are kept on purpose so invalid input stays visible.
            .filter(|(_, coefficient)| coefficient.is_nan() || coefficient.abs() > f64::EPSILON)
            .map(|(id, coefficient)| Term { id, coefficient })
            .collect();
        Self {
            terms,
            constant: self.constant + rhs.constant,
        }
    }
}
/// Adds a scalar to the constant part of a linear function.
impl Add<f64> for Linear {
    type Output = Self;

    /// Returns `self` with `rhs` added to its constant; the terms are untouched.
    fn add(mut self, rhs: f64) -> Self {
        self.constant += rhs;
        self
    }
}
// NOTE(review): these macros come from `crate::macros`; their expansions are not
// visible in this file. Presumably `impl_add_inverse!` provides `f64 + Linear` by
// delegating to `Linear + f64`, and `impl_sub_by_neg_add!` provides `Sub` as
// `lhs + (-rhs)` — confirm against the macro definitions.
// Code in this file uses both `Linear - f64` and `Linear - Linear`.
impl_add_inverse!(f64, Linear);
impl_sub_by_neg_add!(Linear, f64);
impl_sub_by_neg_add!(Linear, Linear);
/// Sums an iterator of linear functions.
impl Sum for Linear {
    /// Folds the items with `+`, starting from the additive identity.
    ///
    /// The seed is `Linear::zero()`. An integer literal must not be used here:
    /// `Linear::from(0)` resolves to `From<u64>`, which builds the variable
    /// term `1.0 * x_0` instead of the zero function.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.fold(Linear::zero(), Add::add)
    }
}
impl Mul<f64> for Linear {
type Output = Self;
fn mul(mut self, rhs: f64) -> Self {
if rhs.is_zero() {
return Linear::zero();
}
for term in &mut self.terms {
term.coefficient *= rhs;
}
self.constant *= rhs;
self
}
}
// NOTE(review): these macros come from `crate::macros`; their expansions are not
// visible in this file. Presumably `impl_mul_inverse!` provides `f64 * Linear` by
// delegating to `Linear * f64`, and `impl_neg_by_mul!` provides `Neg` as
// multiplication by `-1.0` — confirm against the macro definitions.
// Code in this file uses `f64 * Linear`.
impl_mul_inverse!(f64, Linear);
impl_neg_by_mul!(Linear);
/// Product of two linear functions, which is a quadratic function.
impl Mul for Linear {
    type Output = Quadratic;

    /// Expands `(a·x + c) * (b·x + r)`.
    ///
    /// Every pair of terms contributes `a_i * b_j` to the entry keyed by the
    /// two ids ordered as `(smaller, larger)`, so `x_i * x_j` and `x_j * x_i`
    /// accumulate into the same entry. The linear part of the result is
    /// `r·(a·x) + c·(b·x) + c·r`.
    fn mul(self, rhs: Self) -> Quadratic {
        let mut products: BTreeMap<(u64, u64), f64> = BTreeMap::new();
        for lhs_term in &self.terms {
            for rhs_term in &rhs.terms {
                let key = (
                    lhs_term.id.min(rhs_term.id),
                    lhs_term.id.max(rhs_term.id),
                );
                *products.entry(key).or_default() += lhs_term.coefficient * rhs_term.coefficient;
            }
        }

        let lhs_constant = self.constant;
        let rhs_constant = rhs.constant;
        let mut quad: Quadratic = products.into_iter().collect();

        // Each scaled operand carries its own `c * r`, so their sum counts the
        // constant product twice; one copy is subtracted again.
        let scaled_lhs = self * rhs_constant;
        let scaled_rhs = lhs_constant * rhs;
        quad.linear = Some(scaled_lhs + scaled_rhs - rhs_constant * lhs_constant);
        quad
    }
}
/// Approximate equality of linear functions under an absolute tolerance.
impl AbsDiffEq for Linear {
    type Epsilon = crate::ATol;

    /// Uses the default of `crate::ATol`.
    fn default_epsilon() -> Self::Epsilon {
        crate::ATol::default()
    }

    /// Returns `true` when all three conditions hold: both sides store the
    /// same number of terms, the constants differ by at most `epsilon`, and
    /// every coefficient of `self - other` has magnitude at most `epsilon`.
    fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
        let tolerance: f64 = *epsilon;
        let same_term_count = self.terms.len() == other.terms.len();
        let constants_close = self.constant.abs_diff_eq(&other.constant, tolerance);
        if !(same_term_count && constants_close) {
            return false;
        }
        let difference = self.clone() - other.clone();
        difference
            .terms
            .iter()
            .all(|term| term.coefficient.abs() <= tolerance)
    }
}
/// Human-readable rendering of a linear function.
impl fmt::Display for Linear {
    /// Writes `"0"` for the zero function; otherwise hands the non-zero
    /// entries to `crate::format::format_polynomial`.
    ///
    /// Each entry's optional variable id is turned into a collection of
    /// `VariableID`s that is empty for the constant entry.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.is_zero() {
            write!(f, "0")
        } else {
            crate::format::format_polynomial(
                f,
                self.into_iter().map(|(variable, coefficient)| {
                    (
                        variable.into_iter().map(VariableID::from).collect(),
                        coefficient,
                    )
                }),
            )
        }
    }
}
/// Evaluation of a linear function against variable assignments.
impl Evaluate for Linear {
    type Output = f64;
    type SampledOutput = SampledValues;

    /// Computes `constant + Σ coefficient * value` using the values in `solution`.
    ///
    /// # Errors
    ///
    /// Fails when a term's variable id has no entry in `solution.entries`.
    /// The tolerance argument is not used.
    fn evaluate(&self, solution: &State, _atol: crate::ATol) -> Result<f64> {
        let mut sum = self.constant;
        for Term { id, coefficient } in &self.terms {
            let value = solution
                .entries
                .get(id)
                .with_context(|| format!("Variable id ({id}) is not found in the solution"))?;
            sum += coefficient * value;
        }
        Ok(sum)
    }

    /// Substitutes every variable that has a value in `state`.
    ///
    /// Each substituted term is folded into the constant and removed; terms
    /// without a value in `state` are kept. `Vec::retain` is used so the
    /// surviving terms keep their relative order, matching the ascending-id
    /// order that the constructors and `Add` produce. The tolerance argument
    /// is not used, and this method always returns `Ok(())`.
    fn partial_evaluate(&mut self, state: &State, _atol: crate::ATol) -> Result<()> {
        // Accumulate in a local so the closure does not need to borrow `self`
        // while `self.terms` is mutably borrowed by `retain`.
        let mut constant = self.constant;
        self.terms.retain(|term| match state.entries.get(&term.id) {
            Some(value) => {
                constant += term.coefficient * value;
                false
            }
            None => true,
        });
        self.constant = constant;
        Ok(())
    }

    /// Evaluates this function for each state visited by `samples.map`.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by [`Evaluate::evaluate`] or by
    /// `samples.map` itself.
    fn evaluate_samples(
        &self,
        samples: &Samples,
        atol: crate::ATol,
    ) -> Result<Self::SampledOutput> {
        let out = samples.map(|s| {
            let value = self.evaluate(s, atol)?;
            Ok(value)
        })?;
        Ok(out)
    }

    /// Returns the ids of all stored terms as a `VariableIDSet`.
    fn required_ids(&self) -> VariableIDSet {
        self.terms
            .iter()
            .map(|term| VariableID::from(term.id))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    test_algebraic!(super::Linear);

    /// Shorthand for building a `Linear` from `(id, coefficient)` pairs.
    fn linear_of(terms: &[(u64, f64)], constant: f64) -> super::Linear {
        super::Linear::new(terms.iter().copied(), constant)
    }

    /// Checks the textual rendering, with and without an explicit precision.
    #[test]
    fn format() {
        // Mixed signs, a unit coefficient, a non-terminating fraction and a constant.
        {
            let linear = linear_of(&[(1, 1.0), (2, -1.0), (3, -2.0), (4, 1.0 / 3.0)], 3.0);
            assert_eq!(
                linear.to_string(),
                "x1 - x2 - 2*x3 + 0.3333333333333333*x4 + 3"
            );
            assert_eq!(format!("{linear:.2}"), "x1 - x2 - 2.00*x3 + 0.33*x4 + 3.00");
        }

        // The zero function.
        assert_eq!(super::Linear::zero().to_string(), "0");

        // A single negated variable without a constant.
        {
            let linear = linear_of(&[(1, -1.0)], 0.0);
            assert_eq!(linear.to_string(), "-x1");
        }

        // A positive constant after a unit term.
        {
            let linear = linear_of(&[(1, 1.0)], 1.0);
            assert_eq!(linear.to_string(), "x1 + 1");
            assert_eq!(format!("{linear:.2}"), "x1 + 1.00");
        }

        // A negative constant after a unit term.
        {
            let linear = linear_of(&[(1, 1.0)], -1.0);
            assert_eq!(linear.to_string(), "x1 - 1");
            assert_eq!(format!("{linear:.2}"), "x1 - 1.00");
        }
    }
}