use std::collections::{BTreeMap, BTreeSet, HashMap};
use crate::syntax::names::GenericParamName;
/// Error signalling that a type-level `Nat` computation did not fit in a `u64`.
///
/// The type carries no payload; the message produced by its `Display` impl is
/// the only diagnostic information.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NatOverflowError;

impl std::fmt::Display for NatOverflowError {
    /// Writes the fixed overflow message to the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("type-level Nat arithmetic overflow (values are stored as `u64`)")
    }
}

impl std::error::Error for NatOverflowError {}
/// A product of generic parameters raised to `u64` powers.
///
/// The map goes from parameter name to exponent. An empty map is the constant
/// monomial (the empty product), which evaluates to `1`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct Monomial(pub(crate) BTreeMap<GenericParamName, u64>);
impl Monomial {
    /// Returns the constant monomial: the empty product, with no variables.
    #[must_use]
    pub(crate) const fn constant() -> Self {
        Self(BTreeMap::new())
    }

    /// Returns the monomial consisting of `name` raised to the first power.
    #[must_use]
    pub(crate) fn var(name: GenericParamName) -> Self {
        Self(BTreeMap::from([(name, 1)]))
    }

    /// Returns `true` when the monomial contains no variables.
    #[must_use]
    pub(crate) fn is_constant(&self) -> bool {
        self.0.is_empty()
    }

    /// Multiplies two monomials by adding the exponents of matching variables.
    ///
    /// # Errors
    ///
    /// Returns [`NatOverflowError`] when a summed exponent does not fit in `u64`.
    pub(crate) fn mul(&self, other: &Self) -> Result<Self, NatOverflowError> {
        let mut exponents = self.0.clone();
        for (var, exp) in &other.0 {
            let slot = exponents.entry(var.clone()).or_insert(0);
            *slot = slot.checked_add(*exp).ok_or(NatOverflowError)?;
        }
        Ok(Self(exponents))
    }

    /// Computes `base^exp` for a `u64` exponent, returning `None` on overflow.
    ///
    /// `u64::checked_pow` only accepts a `u32` exponent. Larger exponents are
    /// still representable when the base is `0` or `1` (the power equals the
    /// base); every base of `2` or more overflows `u64` long before that.
    fn checked_power(base: u64, exp: u64) -> Option<u64> {
        match u32::try_from(exp) {
            Ok(small) => base.checked_pow(small),
            // Here `exp > u32::MAX`, so in particular `exp > 0`.
            Err(_) => match base {
                0 | 1 => Some(base),
                _ => None,
            },
        }
    }

    /// Evaluates the monomial under `bindings`.
    ///
    /// Returns `None` when a variable has no binding or when the product does
    /// not fit in `u64`.
    #[must_use]
    pub(crate) fn evaluate(&self, bindings: &HashMap<GenericParamName, u64>) -> Option<u64> {
        self.0.iter().try_fold(1u64, |acc, (var, exp)| {
            let base = *bindings.get(var)?;
            acc.checked_mul(Self::checked_power(base, *exp)?)
        })
    }

    /// Partially evaluates the monomial under `bindings`.
    ///
    /// Bound variables are folded into a numeric factor; unbound variables
    /// are kept, with their exponents, in the returned monomial. Returns
    /// `None` when the numeric factor does not fit in `u64`.
    #[must_use]
    pub(crate) fn substitute(
        &self,
        bindings: &HashMap<GenericParamName, u64>,
    ) -> Option<(Self, u64)> {
        let mut remaining = BTreeMap::new();
        let mut factor: u64 = 1;
        for (var, exp) in &self.0 {
            match bindings.get(var) {
                Some(&base) => {
                    factor = factor.checked_mul(Self::checked_power(base, *exp)?)?;
                }
                None => {
                    remaining.insert(var.clone(), *exp);
                }
            }
        }
        Some((Self(remaining), factor))
    }

    /// Renders the monomial as `a * b^2 * ...`.
    ///
    /// An exponent of `1` is omitted; the constant monomial renders as an
    /// empty string.
    #[must_use]
    pub(crate) fn format(&self) -> String {
        self.0
            .iter()
            .map(|(var, exp)| {
                if *exp == 1 {
                    var.to_string()
                } else {
                    format!("{var}^{exp}")
                }
            })
            .collect::<Vec<_>>()
            .join(" * ")
    }
}
impl PartialOrd for Monomial {
    /// Delegates to [`Ord::cmp`]; monomials are totally ordered.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Monomial {
    /// Orders monomials lexicographically by their `(variable, exponent)`
    /// pairs, visited in ascending variable order.
    ///
    /// The two map iterators are compared directly, so no temporary vectors
    /// are allocated per comparison.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.0.iter().cmp(other.0.iter())
    }
}
/// A polynomial over generic parameters with `u64` coefficients.
///
/// The constructors and arithmetic in this file never store a zero
/// coefficient, so the zero polynomial is an empty map.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NatPolyForm {
/// Coefficient of each monomial; the constant term is keyed by the empty monomial.
pub(crate) terms: BTreeMap<Monomial, u64>,
}
/// Alias for [`NatPolyForm`].
///
/// NOTE(review): the name suggests it predates support for non-linear terms
/// and is kept for existing callers — confirm before removing.
pub type NatLinearForm = NatPolyForm;
impl NatPolyForm {
    /// Builds the constant polynomial `c`.
    ///
    /// Zero is represented by an empty term map, not by a stored `0`.
    #[must_use]
    pub fn from_constant(c: u64) -> Self {
        let terms = if c == 0 {
            BTreeMap::new()
        } else {
            BTreeMap::from([(Monomial::constant(), c)])
        };
        Self { terms }
    }

    /// Builds the polynomial consisting of the single variable `name`.
    #[must_use]
    pub fn from_var(name: GenericParamName) -> Self {
        Self {
            terms: BTreeMap::from([(Monomial::var(name), 1)]),
        }
    }

    /// Adds two polynomials term by term.
    ///
    /// # Errors
    ///
    /// Returns [`NatOverflowError`] when a summed coefficient exceeds `u64`.
    pub fn add(&self, other: &Self) -> Result<Self, NatOverflowError> {
        let mut sum = self.terms.clone();
        for (mono, &coeff) in &other.terms {
            let slot = sum.entry(mono.clone()).or_default();
            *slot = slot.checked_add(coeff).ok_or(NatOverflowError)?;
        }
        // Drop any zero coefficients so they are never stored.
        sum.retain(|_, coeff| *coeff != 0);
        Ok(Self { terms: sum })
    }

    /// Multiplies two polynomials by distributing every pair of terms.
    ///
    /// # Errors
    ///
    /// Returns [`NatOverflowError`] when an exponent or a coefficient exceeds
    /// `u64` at any point during the expansion.
    pub fn mul(&self, other: &Self) -> Result<Self, NatOverflowError> {
        let mut product: BTreeMap<Monomial, u64> = BTreeMap::new();
        for (lhs_mono, &lhs_coeff) in &self.terms {
            for (rhs_mono, &rhs_coeff) in &other.terms {
                let mono = lhs_mono.mul(rhs_mono)?;
                let coeff = lhs_coeff.checked_mul(rhs_coeff).ok_or(NatOverflowError)?;
                let slot = product.entry(mono).or_default();
                *slot = slot.checked_add(coeff).ok_or(NatOverflowError)?;
            }
        }
        product.retain(|_, coeff| *coeff != 0);
        Ok(Self { terms: product })
    }

    /// Returns the constant term, or `0` when there is none.
    #[must_use]
    pub fn constant(&self) -> u64 {
        self.terms.get(&Monomial::constant()).map_or(0, |c| *c)
    }

    /// Returns `true` when no term mentions a variable.
    ///
    /// The zero polynomial (no terms at all) counts as constant.
    #[must_use]
    pub fn is_constant(&self) -> bool {
        self.terms.keys().all(Monomial::is_constant)
    }

    /// Evaluates the polynomial under `bindings`.
    ///
    /// Returns `None` when a monomial cannot be evaluated (for example, a
    /// variable is unbound) or when any intermediate value exceeds `u64`.
    #[must_use]
    pub fn evaluate(&self, bindings: &HashMap<GenericParamName, u64>) -> Option<u64> {
        self.terms.iter().try_fold(0u64, |acc, (mono, coeff)| {
            let term = coeff.checked_mul(mono.evaluate(bindings)?)?;
            acc.checked_add(term)
        })
    }

    /// Renders the polynomial as a ` + `-separated sum.
    ///
    /// Variable terms come first in map order, with a coefficient of `1`
    /// omitted; the constant term comes last. The zero polynomial renders
    /// as `"0"`.
    #[must_use]
    pub fn format(&self) -> String {
        if self.terms.is_empty() {
            return "0".to_string();
        }

        let mut parts: Vec<String> = self
            .terms
            .iter()
            .filter(|(mono, _)| !mono.is_constant())
            .map(|(mono, coeff)| match *coeff {
                1 => mono.format(),
                k => format!("{k} * {}", mono.format()),
            })
            .collect();

        // A zero constant is shown only when nothing else was printed.
        match self.terms.get(&Monomial::constant()) {
            Some(&c) if c > 0 || parts.is_empty() => parts.push(c.to_string()),
            _ => {}
        }

        if parts.is_empty() {
            "0".to_string()
        } else {
            parts.join(" + ")
        }
    }

    /// Coefficient-wise comparison.
    ///
    /// Returns `true` when every coefficient of `self` is less than or equal
    /// to the coefficient of the same monomial in `other`, with a missing
    /// monomial counting as `0`.
    #[must_use]
    pub fn is_leq(&self, other: &Self) -> bool {
        for (mono, &coeff) in &self.terms {
            let bound = other.terms.get(mono).map_or(0, |c| *c);
            if coeff > bound {
                return false;
            }
        }
        true
    }

    /// Collects every variable name that appears in any term.
    #[must_use]
    pub fn variables(&self) -> BTreeSet<GenericParamName> {
        let mut names = BTreeSet::new();
        for mono in self.terms.keys() {
            names.extend(mono.0.keys().cloned());
        }
        names
    }
}
impl std::fmt::Display for NatPolyForm {
    /// Writes the same text that [`NatPolyForm::format`] returns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rendered = self.format();
        f.write_str(&rendered)
    }
}