use super::semiring::Tropical;
/// A single tropical (max-plus) monomial `coeff ⊗ x^exp`, i.e. the affine
/// function `x ↦ coeff + exp * x`.
///
/// A coefficient of `f64::NEG_INFINITY` is the tropical zero: such a monomial
/// evaluates to `-inf` everywhere.
#[derive(Debug, Clone, Copy)]
pub struct TropicalMonomial {
    /// Additive offset of the affine function; `-inf` means tropical zero.
    pub coeff: f64,
    /// Slope of the affine function (the tropical exponent).
    pub exp: i32,
}

impl TropicalMonomial {
    /// Creates the monomial `coeff ⊗ x^exp`.
    pub fn new(coeff: f64, exp: i32) -> Self {
        Self { coeff, exp }
    }

    /// Evaluates the monomial at `x`, i.e. returns `coeff + exp * x`.
    ///
    /// A `-inf` coefficient yields `-inf` for every `x`. A zero exponent yields
    /// `coeff` for every `x`, including `x = ±inf`, where the plain product
    /// `0 * x` would be NaN.
    #[inline]
    pub fn eval(&self, x: f64) -> f64 {
        if self.coeff == f64::NEG_INFINITY {
            f64::NEG_INFINITY
        } else if self.exp == 0 {
            // x^0 is the tropical one (0.0) for every x, even x = ±inf.
            self.coeff
        } else {
            self.coeff + f64::from(self.exp) * x
        }
    }

    /// Tropical product: coefficients add and exponents add.
    ///
    /// The tropical zero is absorbing. If either coefficient is `-inf`, the
    /// product's coefficient is `-inf`; plain addition would give NaN against `+inf`.
    ///
    /// # Panics
    /// Panics in debug builds if the exponent sum overflows `i32`.
    pub fn mul(&self, other: &Self) -> Self {
        let coeff = if self.coeff == f64::NEG_INFINITY || other.coeff == f64::NEG_INFINITY {
            f64::NEG_INFINITY
        } else {
            self.coeff + other.coeff
        };
        Self {
            coeff,
            exp: self.exp + other.exp,
        }
    }
}
/// A univariate tropical (max-plus) polynomial `p(x) = max_k (c_k + e_k * x)`.
///
/// Every constructor keeps `terms` sorted by increasing exponent. `from_monomials`,
/// `add` and `mul` additionally merge like terms and drop tropical-zero (`-inf`)
/// coefficients; `from_coeffs` produces distinct exponents and drops `-inf` directly.
#[derive(Debug, Clone)]
pub struct TropicalPolynomial {
    terms: Vec<TropicalMonomial>,
}

impl TropicalPolynomial {
    /// Roots closer together than this are reported once (absorbs rounding noise).
    const ROOT_TOL: f64 = 1e-10;

    /// Builds a polynomial from dense coefficients, where `coeffs[i]` multiplies `x^i`.
    ///
    /// `-inf` entries (the tropical zero) are omitted.
    pub fn from_coeffs(coeffs: &[f64]) -> Self {
        let terms = coeffs
            .iter()
            .enumerate()
            .filter(|&(_, &c)| c != f64::NEG_INFINITY)
            .map(|(i, &c)| TropicalMonomial::new(c, i as i32))
            .collect();
        Self { terms }
    }

    /// Builds a polynomial from arbitrary monomials.
    ///
    /// Terms are sorted by exponent. Monomials sharing an exponent are merged by
    /// tropical addition (the larger coefficient wins), and `-inf` coefficients are
    /// dropped, so the result has the same shape as `from_coeffs`, `add` and `mul` produce.
    pub fn from_monomials(terms: Vec<TropicalMonomial>) -> Self {
        Self {
            terms: Self::normalize(terms),
        }
    }

    /// Number of stored terms.
    pub fn num_terms(&self) -> usize {
        self.terms.len()
    }

    /// Evaluates `max_k (c_k + e_k * x)`; a polynomial with no terms evaluates to `-inf`.
    pub fn eval(&self, x: f64) -> f64 {
        self.terms
            .iter()
            .map(|m| m.eval(x))
            .fold(f64::NEG_INFINITY, f64::max)
    }

    /// Returns the distinct tropical roots in increasing order.
    ///
    /// A root is a point where the maximum is attained by at least two terms, i.e.
    /// where `p` changes slope. Roots correspond to the edges of the upper convex hull
    /// of the points `(exp, coeff)` (the Newton polygon). Consecutive hull vertices
    /// `(e_i, c_i)`, `(e_j, c_j)` give the root `(c_i - c_j) / (e_j - e_i)`.
    ///
    /// The hull is used instead of evaluating `p` at candidate points and comparing
    /// against an absolute tolerance. That comparison misses real roots once the
    /// coefficients are large enough that rounding exceeds the tolerance.
    /// Only finite coefficients participate.
    pub fn roots(&self) -> Vec<f64> {
        let hull = self.upper_hull();
        let mut roots: Vec<f64> = hull
            .windows(2)
            .map(|w| (w[0].1 - w[1].1) / (w[1].0 - w[0].0))
            .collect();
        // Hull slopes strictly decrease, so the roots already ascend in exact
        // arithmetic; sorting guarantees the ordering even under rounding.
        roots.sort_by(f64::total_cmp);
        roots.dedup_by(|a, b| (*a - *b).abs() < Self::ROOT_TOL);
        roots
    }

    /// Number of maximal intervals on which `p` is affine: one more than the
    /// number of distinct roots.
    pub fn num_linear_regions(&self) -> usize {
        1 + self.roots().len()
    }

    /// Tropical product: every pair of terms is multiplied, then like terms are merged.
    pub fn mul(&self, other: &Self) -> Self {
        let products = self
            .terms
            .iter()
            .flat_map(|a| other.terms.iter().map(move |b| a.mul(b)))
            .collect();
        Self {
            terms: Self::normalize(products),
        }
    }

    /// Tropical sum: the union of both term lists, with like terms merged by taking
    /// the larger coefficient.
    pub fn add(&self, other: &Self) -> Self {
        let combined = self.terms.iter().chain(&other.terms).copied().collect();
        Self {
            terms: Self::normalize(combined),
        }
    }

    /// Sorts `terms` by exponent, merges like terms by tropical addition (max of the
    /// coefficients), and drops tropical-zero (`-inf`) coefficients.
    fn normalize(mut terms: Vec<TropicalMonomial>) -> Vec<TropicalMonomial> {
        terms.retain(|m| m.coeff != f64::NEG_INFINITY);
        terms.sort_by_key(|m| m.exp);
        // `dedup_by` passes (later, kept); fold each later like term into the kept one.
        terms.dedup_by(|later, kept| {
            let like = later.exp == kept.exp;
            if like {
                kept.coeff = kept.coeff.max(later.coeff);
            }
            like
        });
        terms
    }

    /// Upper convex hull of the points `(exp, coeff)` from left to right, with
    /// points on or below a hull edge removed.
    ///
    /// This is a single monotone-chain pass and relies on `terms` being sorted by
    /// exponent. Non-finite coefficients are skipped.
    fn upper_hull(&self) -> Vec<(f64, f64)> {
        let mut hull: Vec<(f64, f64)> = Vec::with_capacity(self.terms.len());
        for m in self.terms.iter().filter(|m| m.coeff.is_finite()) {
            let (e, c) = (f64::from(m.exp), m.coeff);
            // Like terms: only the larger coefficient can ever attain the maximum.
            if let Some(&(last_e, last_c)) = hull.last() {
                if last_e == e {
                    if c <= last_c {
                        continue;
                    }
                    hull.pop();
                }
            }
            // Pop the previous vertex while it lies on or below the segment from its
            // predecessor to the new point (non-negative cross product).
            while hull.len() >= 2 {
                let (e1, c1) = hull[hull.len() - 2];
                let (e2, c2) = hull[hull.len() - 1];
                let cross = (e2 - e1) * (c - c1) - (c2 - c1) * (e - e1);
                if cross >= 0.0 {
                    hull.pop();
                } else {
                    break;
                }
            }
            hull.push((e, c));
        }
        hull
    }
}
/// A tropical (max-plus) polynomial in `nvars` variables,
/// `p(x) = max_k (c_k + sum_i e_k[i] * x[i])`.
#[derive(Debug, Clone)]
pub struct MultivariateTropicalPolynomial {
    nvars: usize,
    /// `(coefficient, exponent vector)` pairs. Exponent vectors are expected to
    /// have `nvars` entries. `eval` pairs them with the point via `zip`, so surplus
    /// entries on either side are ignored.
    terms: Vec<(f64, Vec<i32>)>,
}

impl MultivariateTropicalPolynomial {
    /// Creates a polynomial over `nvars` variables from `(coefficient, exponents)` pairs.
    ///
    /// Terms are stored exactly as given (no sorting, merging or validation).
    pub fn new(nvars: usize, terms: Vec<(f64, Vec<i32>)>) -> Self {
        Self { nvars, terms }
    }

    /// Evaluates the polynomial at the point `x`.
    ///
    /// Terms with a `-inf` coefficient (tropical zero) are skipped. A zero exponent
    /// contributes nothing even when the matching coordinate is `±inf`; the plain
    /// product `0 * x[i]` would be NaN and `f64::max` would then silently discard the
    /// whole term. A polynomial with no terms evaluates to `-inf`.
    ///
    /// # Panics
    /// Panics if `x.len() != nvars`.
    pub fn eval(&self, x: &[f64]) -> f64 {
        assert_eq!(x.len(), self.nvars);
        self.terms
            .iter()
            .filter(|(coeff, _)| *coeff != f64::NEG_INFINITY)
            .map(|(coeff, exp)| {
                let linear: f64 = exp
                    .iter()
                    .zip(x)
                    .map(|(&e, &xi)| if e == 0 { 0.0 } else { f64::from(e) * xi })
                    .sum();
                coeff + linear
            })
            .fold(f64::NEG_INFINITY, f64::max)
    }

    /// Number of stored terms, including any with a `-inf` coefficient.
    pub fn num_terms(&self) -> usize {
        self.terms.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` is within `EPS` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < EPS,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_tropical_polynomial_eval() {
        // p(x) = max(2, 1 + x, -1 + 2x)
        let p = TropicalPolynomial::from_coeffs(&[2.0, 1.0, -1.0]);
        assert_close(p.eval(0.0), 2.0);
        assert_close(p.eval(1.0), 2.0);
        assert_close(p.eval(3.0), 5.0);
    }

    #[test]
    fn test_tropical_roots() {
        let p = TropicalPolynomial::from_coeffs(&[0.0, 0.0]);
        let roots = p.roots();
        assert_eq!(roots.len(), 1);
        assert_close(roots[0], 0.0);

        // max(2, 1 + x, -1 + 2x) changes slope at x = 1 and at x = 2.
        let q = TropicalPolynomial::from_coeffs(&[2.0, 1.0, -1.0]);
        let roots = q.roots();
        assert_eq!(roots.len(), 2);
        assert_close(roots[0], 1.0);
        assert_close(roots[1], 2.0);
        assert_eq!(q.num_linear_regions(), 3);
    }

    #[test]
    fn test_tropical_mul() {
        let p = TropicalPolynomial::from_coeffs(&[1.0, 2.0]);
        let q = TropicalPolynomial::from_coeffs(&[0.0, 1.0]);
        let pq = p.mul(&q);
        // max(1, 2 + x) ⊗ max(0, 1 + x) = max(1, 2 + x, 3 + 2x): one term per exponent.
        assert_eq!(pq.num_terms(), 3);
        // Evaluation turns tropical multiplication into ordinary addition.
        for x in [-2.0, -0.5, 0.0, 0.5, 3.0] {
            assert_close(pq.eval(x), p.eval(x) + q.eval(x));
        }
    }

    #[test]
    fn test_tropical_add() {
        let p = TropicalPolynomial::from_coeffs(&[1.0, 2.0]);
        let q = TropicalPolynomial::from_coeffs(&[0.0, 1.0]);
        let sum = p.add(&q);
        // Like terms merge, so only exponents 0 and 1 remain.
        assert_eq!(sum.num_terms(), 2);
        // Evaluation turns tropical addition into a pointwise maximum.
        for x in [-2.0, -0.5, 0.0, 0.5, 3.0] {
            assert_close(sum.eval(x), p.eval(x).max(q.eval(x)));
        }
    }

    #[test]
    fn test_multivariate() {
        // p(x, y) = max(0, x, y)
        let p = MultivariateTropicalPolynomial::new(
            2,
            vec![(0.0, vec![0, 0]), (0.0, vec![1, 0]), (0.0, vec![0, 1])],
        );
        assert_close(p.eval(&[1.0, 2.0]), 2.0);
        assert_close(p.eval(&[3.0, 1.0]), 3.0);
    }
}