use std::collections::BTreeMap;
use crate::term::{Literal, Term};
/// A product of variables raised to powers, e.g. `x0^2 * x3`.
///
/// Stored as a map from variable index to exponent. Because the map is a
/// `BTreeMap`, the derived `Eq`/`Ord`/`Hash` compare monomials structurally,
/// independent of the order in which factors were multiplied. Built only via
/// `one`, `var` and `mul`, every stored exponent is at least 1.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Monomial {
    /// Variable index -> exponent.
    powers: BTreeMap<i64, u32>,
}
impl Monomial {
    /// The empty product, i.e. the monomial `1`.
    pub fn one() -> Self {
        Self {
            powers: BTreeMap::default(),
        }
    }

    /// The single variable `index` raised to the first power.
    pub fn var(index: i64) -> Self {
        Self {
            powers: std::iter::once((index, 1)).collect(),
        }
    }

    /// Multiplies two monomials by adding the exponents of each variable.
    pub fn mul(&self, other: &Monomial) -> Monomial {
        other
            .powers
            .iter()
            .fold(self.clone(), |mut acc, (&variable, &exponent)| {
                *acc.powers.entry(variable).or_default() += exponent;
                acc
            })
    }
}
/// A multivariate polynomial with `i64` coefficients, kept in canonical form.
///
/// Terms are keyed by [`Monomial`] in a `BTreeMap`, and every constructor and
/// operation below drops zero coefficients. Two polynomials therefore denote
/// the same value exactly when their term maps are identical, which is what
/// `canonical_eq` (and the derived `PartialEq`) compares.
///
/// Coefficient arithmetic uses the plain `i64` operators. With default build
/// profiles, an overflow therefore panics in debug builds and wraps in release
/// builds. NOTE(review): confirm this is acceptable for the expected
/// coefficient range.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Polynomial {
    /// Monomial -> non-zero coefficient.
    terms: BTreeMap<Monomial, i64>,
}
impl Polynomial {
    /// The zero polynomial, represented by an empty term map.
    pub fn zero() -> Self {
        Self {
            terms: BTreeMap::default(),
        }
    }

    /// The constant polynomial `c`; `constant(0)` is identical to `zero()`.
    pub fn constant(c: i64) -> Self {
        if c != 0 {
            Self {
                terms: std::iter::once((Monomial::one(), c)).collect(),
            }
        } else {
            Self::zero()
        }
    }

    /// The polynomial made of the single variable `index` with coefficient 1.
    pub fn var(index: i64) -> Self {
        Self {
            terms: std::iter::once((Monomial::var(index), 1)).collect(),
        }
    }

    /// Returns `self + other`, dropping any term whose coefficients cancel.
    pub fn add(&self, other: &Polynomial) -> Polynomial {
        let mut sum = self.terms.clone();
        for (mono, &coeff) in other.terms.iter() {
            let updated = sum.get(mono).copied().unwrap_or(0) + coeff;
            if updated == 0 {
                sum.remove(mono);
            } else {
                sum.insert(mono.clone(), updated);
            }
        }
        Polynomial { terms: sum }
    }

    /// Returns `-self` by negating every coefficient.
    pub fn neg(&self) -> Polynomial {
        let terms = self
            .terms
            .iter()
            .map(|(mono, &coeff)| (mono.clone(), -coeff))
            .collect();
        Polynomial { terms }
    }

    /// Returns `self - other`, computed as `self + (-other)`.
    pub fn sub(&self, other: &Polynomial) -> Polynomial {
        let negated = other.neg();
        self.add(&negated)
    }

    /// Returns `self * other`. Every term of `self` is distributed over every
    /// term of `other`, and coefficients that summed to zero are then dropped.
    pub fn mul(&self, other: &Polynomial) -> Polynomial {
        let mut product: BTreeMap<Monomial, i64> = BTreeMap::new();
        for (lhs_mono, &lhs_coeff) in &self.terms {
            for (rhs_mono, &rhs_coeff) in &other.terms {
                *product.entry(lhs_mono.mul(rhs_mono)).or_default() += lhs_coeff * rhs_coeff;
            }
        }
        product.retain(|_, coeff| *coeff != 0);
        Polynomial { terms: product }
    }

    /// Returns `true` when both polynomials have identical canonical term maps.
    pub fn canonical_eq(&self, other: &Polynomial) -> bool {
        self.terms.eq(&other.terms)
    }
}
/// Reasons a quoted term cannot be converted into a polynomial by `reify`.
#[derive(Debug)]
pub enum ReifyError {
    /// The term (or one of its subterms) lies outside the polynomial ring.
    /// The payload is a human-readable explanation.
    NonPolynomial(String),
    /// The term is structurally malformed. Not returned by `reify` as written
    /// in this module.
    MalformedTerm,
}

// User-facing rendering so the error can be printed and boxed as
// `dyn std::error::Error` (e.g. propagated with `?` into `Box<dyn Error>`).
impl std::fmt::Display for ReifyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ReifyError::NonPolynomial(reason) => write!(f, "non-polynomial term: {}", reason),
            ReifyError::MalformedTerm => f.write_str("malformed term"),
        }
    }
}

impl std::error::Error for ReifyError {}
/// Converts a quoted syntax term into a canonical `Polynomial`.
///
/// Recognized shapes, checked in this order:
/// - `SLit(n)` with an integer literal -> the constant `n`
/// - `SVar(i)` with an integer literal -> the variable `i`
/// - `SName(s)` with a text literal -> the variable `name_to_var_index(s)`
/// - `SApp(SApp(SName(op), a), b)` with `op` one of `add`/`sub`/`mul` -> the
///   corresponding ring operation applied to the reified operands
///
/// # Errors
/// Returns `ReifyError::NonPolynomial` for the operators `div`/`mod`, for any
/// other operator name, and for a term that matches none of the shapes above.
/// An unsupported operator is reported before its operands are examined.
/// Errors from reifying an operand are propagated; the left operand is reified
/// before the right one.
///
/// NOTE(review): named variables are mapped by hashing, not interning, so two
/// distinct names whose hashes coincide would become the same variable. An
/// `SVar` whose index equals a name's hashed index would do the same.
pub fn reify(term: &Term) -> Result<Polynomial, ReifyError> {
    // Leaves: integer literals, indexed variables and named variables.
    if let Some(value) = extract_slit(term) {
        return Ok(Polynomial::constant(value));
    }
    if let Some(index) = extract_svar(term) {
        return Ok(Polynomial::var(index));
    }
    if let Some(name) = extract_sname(term) {
        return Ok(Polynomial::var(name_to_var_index(&name)));
    }

    let (op, lhs, rhs) = match extract_binary_app(term) {
        Some(parts) => parts,
        None => {
            return Err(ReifyError::NonPolynomial(
                "Unrecognized term structure".to_string(),
            ))
        }
    };

    // Select the ring operation first, so unsupported operators are rejected
    // without reifying (and possibly failing on) their arguments.
    let combine: fn(&Polynomial, &Polynomial) -> Polynomial = match op.as_str() {
        "add" => Polynomial::add,
        "sub" => Polynomial::sub,
        "mul" => Polynomial::mul,
        "div" | "mod" => {
            return Err(ReifyError::NonPolynomial(format!(
                "Operation '{}' is not supported in ring",
                op
            )))
        }
        _ => {
            return Err(ReifyError::NonPolynomial(format!(
                "Unknown operation '{}'",
                op
            )))
        }
    };

    let left = reify(&lhs)?;
    let right = reify(&rhs)?;
    Ok(combine(&left, &right))
}
/// If `term` is the constructor `SLit` applied to an integer literal, returns
/// that integer; otherwise `None`.
fn extract_slit(term: &Term) -> Option<i64> {
    match term {
        Term::App(ctor, arg) => {
            let tagged_slit = matches!(ctor.as_ref(), Term::Global(name) if name == "SLit");
            match arg.as_ref() {
                Term::Lit(Literal::Int(value)) if tagged_slit => Some(*value),
                _ => None,
            }
        }
        _ => None,
    }
}
/// If `term` is the constructor `SVar` applied to an integer literal, returns
/// that integer (the variable index); otherwise `None`.
fn extract_svar(term: &Term) -> Option<i64> {
    match term {
        Term::App(ctor, arg) if matches!(ctor.as_ref(), Term::Global(name) if name == "SVar") => {
            match arg.as_ref() {
                Term::Lit(Literal::Int(index)) => Some(*index),
                _ => None,
            }
        }
        _ => None,
    }
}
/// If `term` is the constructor `SName` applied to a text literal, returns a
/// copy of that text; otherwise `None`.
fn extract_sname(term: &Term) -> Option<String> {
    let (ctor, arg) = match term {
        Term::App(ctor, arg) => (ctor, arg),
        _ => return None,
    };
    if !matches!(ctor.as_ref(), Term::Global(name) if name == "SName") {
        return None;
    }
    match arg.as_ref() {
        Term::Lit(Literal::Text(text)) => Some(text.clone()),
        _ => None,
    }
}
/// Recognizes a binary operator application `SApp(SApp(SName(op), a), b)`
/// and returns `(op, a, b)`.
///
/// `SApp` is a curried constructor, so the expected term shape is
/// `App(App(Global("SApp"), App(App(Global("SApp"), SName(op)), a)), b)`.
/// Anything else yields `None`.
///
/// The operands are returned as borrows of `term` rather than deep clones.
/// `reify` recurses into both operands, so cloning the two subtrees at every
/// level would cost O(size * depth). That is quadratic for long chains such as
/// `((a + b) + c) + ...`.
fn extract_binary_app(term: &Term) -> Option<(String, &Term, &Term)> {
    // Splits `App(f, x)` into `(f, x)`.
    fn split_app(t: &Term) -> Option<(&Term, &Term)> {
        match t {
            Term::App(f, x) => Some((f.as_ref(), x.as_ref())),
            _ => None,
        }
    }
    // True when `t` is the global constructor `SApp`.
    fn is_sapp(t: &Term) -> bool {
        matches!(t, Term::Global(name) if name == "SApp")
    }

    // Outer application: SApp(<inner>, b).
    let (outer, b) = split_app(term)?;
    let (sapp_outer, inner) = split_app(outer)?;
    if !is_sapp(sapp_outer) {
        return None;
    }
    // Inner application: SApp(<op_term>, a).
    let (partial, a) = split_app(inner)?;
    let (sapp_inner, op_term) = split_app(partial)?;
    if !is_sapp(sapp_inner) {
        return None;
    }
    let op = extract_sname(op_term)?;
    Some((op, a, b))
}
/// Maps a free variable name to a synthetic variable index.
///
/// The name's bytes are folded into a 31-multiplier rolling hash, with
/// wrapping on overflow. The hash magnitude is then offset into
/// `[-i64::MAX, -1_000_000]`, so named variables always get negative indices at
/// least one million away from zero. Presumably this is meant to keep them
/// clear of small `SVar` indices (TODO confirm intent).
///
/// NOTE(review): this is a hash, not an interning table, so two distinct
/// names can map to the same index.
fn name_to_var_index(name: &str) -> i64 {
    const OFFSET: u64 = 1_000_000;
    // Magnitudes must stay below this bound so that `magnitude + OFFSET`
    // still fits in i64.
    const BOUND: u64 = i64::MAX as u64 - OFFSET + 1;

    let hash: i64 = name
        .bytes()
        .fold(0i64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as i64));
    // `unsigned_abs` is defined for i64::MIN, where `abs` overflows. The
    // modulus prevents `magnitude + OFFSET` from overflowing, which would
    // otherwise panic in debug builds or wrap to a positive index in release
    // builds. For every hash whose magnitude is already below BOUND, the
    // result equals `-(|hash| + 1_000_000)`, exactly as before.
    let magnitude = hash.unsigned_abs() % BOUND;
    -((magnitude + OFFSET) as i64)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constants collapse to canonical form: zero is the empty polynomial,
    /// distinct constants stay distinct, and opposite constants cancel.
    #[test]
    fn test_polynomial_constant() {
        assert!(Polynomial::constant(0).canonical_eq(&Polynomial::zero()));
        assert!(!Polynomial::constant(42).canonical_eq(&Polynomial::constant(41)));
        let sum = Polynomial::constant(42).add(&Polynomial::constant(-42));
        assert!(sum.canonical_eq(&Polynomial::zero()));
    }

    #[test]
    fn test_polynomial_add() {
        let x = Polynomial::var(0);
        let y = Polynomial::var(1);
        let sum1 = x.add(&y);
        let sum2 = y.add(&x);
        assert!(sum1.canonical_eq(&sum2), "x+y should equal y+x");
    }

    #[test]
    fn test_polynomial_mul() {
        let x = Polynomial::var(0);
        let y = Polynomial::var(1);
        let prod1 = x.mul(&y);
        let prod2 = y.mul(&x);
        assert!(prod1.canonical_eq(&prod2), "x*y should equal y*x");
    }

    #[test]
    fn test_polynomial_distributivity() {
        let x = Polynomial::var(0);
        let y = Polynomial::var(1);
        let z = Polynomial::var(2);
        let lhs = x.mul(&y.add(&z));
        let rhs = x.mul(&y).add(&x.mul(&z));
        assert!(lhs.canonical_eq(&rhs));
    }

    #[test]
    fn test_polynomial_subtraction() {
        let x = Polynomial::var(0);
        let result = x.sub(&x);
        assert!(result.canonical_eq(&Polynomial::zero()));
    }

    /// Cross terms that cancel during multiplication must be dropped:
    /// (x + 1)(x - 1) = x^2 - 1.
    #[test]
    fn test_polynomial_difference_of_squares() {
        let x = Polynomial::var(0);
        let one = Polynomial::constant(1);
        let lhs = x.add(&one).mul(&x.sub(&one));
        let rhs = x.mul(&x).sub(&one);
        assert!(lhs.canonical_eq(&rhs), "(x+1)(x-1) should equal x^2-1");
        assert!(!x.mul(&x).canonical_eq(&x), "x^2 should differ from x");
    }

    #[test]
    fn test_polynomial_neg_and_zero_mul() {
        let x = Polynomial::var(0);
        assert!(x.neg().neg().canonical_eq(&x));
        assert!(x.mul(&Polynomial::zero()).canonical_eq(&Polynomial::zero()));
    }

    #[test]
    fn test_collatz_algebra() {
        let k = Polynomial::var(0);
        let two = Polynomial::constant(2);
        let three = Polynomial::constant(3);
        let one = Polynomial::constant(1);
        let four = Polynomial::constant(4);
        let six = Polynomial::constant(6);
        let two_k = two.mul(&k);
        let two_k_plus_1 = two_k.add(&one);
        let three_times = three.mul(&two_k_plus_1);
        let lhs = three_times.add(&one);
        let six_k = six.mul(&k);
        let rhs = six_k.add(&four);
        assert!(lhs.canonical_eq(&rhs), "3(2k+1)+1 should equal 6k+4");
    }

    /// Named variables land in the negative range reserved for them.
    #[test]
    fn test_name_to_var_index_range() {
        assert_eq!(name_to_var_index("x"), -1_000_120);
        assert_ne!(name_to_var_index("x"), name_to_var_index("y"));
        for name in &["a", "foo", "some_rather_long_identifier_name"] {
            assert!(name_to_var_index(name) <= -1_000_000);
        }
    }
}