use crate::scalar::bigrat::BigRat;
use crate::scalar::rat::Rat;
/// An exact rational coefficient with two representations: a fixed-width,
/// `Copy` [`Rat`] fast path and an arbitrary-precision [`BigRat`] fallback.
///
/// Addition and multiplication try checked `Rat` arithmetic first and fall back
/// to `BigRat` when it reports overflow. Values built through
/// [`Coeff::from_bigrat`] are demoted back to `Rat` whenever they fit. A `Big`
/// can still hold a small value if the variant is constructed directly, so
/// equality and ordering compare values across representations rather than
/// comparing variants.
#[derive(Clone, Debug)]
pub enum Coeff {
    /// Fixed-width rational; the common case.
    Rat(Rat),
    /// Arbitrary-precision rational, used when a value does not fit in `Rat`.
    Big(BigRat),
}
impl Coeff {
pub const ZERO: Coeff = Coeff::Rat(Rat::ZERO);
pub const ONE: Coeff = Coeff::Rat(Rat::ONE);
#[inline]
pub fn is_zero(&self) -> bool {
match self {
Coeff::Rat(r) => r.is_zero(),
Coeff::Big(b) => b.is_zero(),
}
}
#[inline]
pub fn is_positive(&self) -> bool {
match self {
Coeff::Rat(r) => r.is_positive(),
Coeff::Big(b) => b.is_positive(),
}
}
#[inline]
pub fn is_negative(&self) -> bool {
match self {
Coeff::Rat(r) => r.is_negative(),
Coeff::Big(b) => b.is_negative(),
}
}
pub fn to_rat(&self) -> Option<Rat> {
match self {
Coeff::Rat(r) => Some(*r),
Coeff::Big(b) => b.to_rat(),
}
}
pub fn canonicalize(self) -> Self {
match self {
Coeff::Big(ref b) => {
if let Some(r) = b.to_rat() {
Coeff::Rat(r)
} else {
self
}
}
other => other,
}
}
pub fn from_rat(r: Rat) -> Self {
Coeff::Rat(r)
}
pub fn from_bigrat(b: BigRat) -> Self {
if let Some(r) = b.to_rat() {
Coeff::Rat(r)
} else {
Coeff::Big(b)
}
}
pub fn from_i64(n: i64) -> Self {
Coeff::Rat(Rat::from(n))
}
pub fn recip(self) -> Self {
match self {
Coeff::Rat(r) => {
if r.is_zero() {
panic!("Coeff: reciprocal of zero");
}
Coeff::Rat(r.recip())
}
Coeff::Big(b) => Coeff::from_bigrat(b.recip()),
}
}
pub fn abs(self) -> Self {
match self {
Coeff::Rat(r) => Coeff::Rat(r.abs()),
Coeff::Big(b) => Coeff::Big(b.abs()),
}
}
}
/// Widens a coefficient to the arbitrary-precision representation.
///
/// `Rat` values are converted with `BigRat::from_rat`; `Big` values are cloned.
fn to_bigrat(c: &Coeff) -> BigRat {
    match *c {
        Coeff::Rat(small) => BigRat::from_rat(small),
        Coeff::Big(ref big) => big.clone(),
    }
}
/// Runs `f` on both operands after widening them to `BigRat`.
///
/// The result goes through `Coeff::from_bigrat`, so it is demoted to `Rat` when
/// it fits.
fn big_op(a: &Coeff, b: &Coeff, f: impl FnOnce(BigRat, BigRat) -> BigRat) -> Coeff {
    let (lhs, rhs) = (to_bigrat(a), to_bigrat(b));
    let wide = f(lhs, rhs);
    Coeff::from_bigrat(wide)
}
impl std::ops::Neg for Coeff {
type Output = Self;
fn neg(self) -> Self {
match self {
Coeff::Rat(r) => Coeff::Rat(-r),
Coeff::Big(b) => Coeff::Big(b.neg()),
}
}
}
impl std::ops::Add for Coeff {
    type Output = Self;

    /// Exact addition.
    ///
    /// The result stays a `Rat` when the checked fixed-width sum succeeds.
    /// Otherwise the sum is computed in `BigRat` and demoted if it fits.
    fn add(self, rhs: Self) -> Self {
        if let (Coeff::Rat(lhs_small), Coeff::Rat(rhs_small)) = (&self, &rhs) {
            if let Some(sum) = lhs_small.checked_add(*rhs_small) {
                return Coeff::Rat(sum);
            }
        }
        // Reached for mixed representations, two `Big`s, or fixed-width overflow.
        big_op(&self, &rhs, |x, y| x.add(&y))
    }
}
impl std::ops::Sub for Coeff {
    type Output = Self;

    /// Subtraction, written as addition of the negated right-hand side.
    ///
    /// This way it reuses `Add`'s fixed-width/`BigRat` fallback logic.
    fn sub(self, rhs: Self) -> Self {
        let negated_rhs = -rhs;
        self + negated_rhs
    }
}
impl std::ops::Mul for Coeff {
    type Output = Self;

    /// Exact multiplication.
    ///
    /// The result stays a `Rat` when the checked fixed-width product succeeds.
    /// Otherwise the product is computed in `BigRat` and demoted if it fits.
    fn mul(self, rhs: Self) -> Self {
        let fixed_width = match (&self, &rhs) {
            (Coeff::Rat(x), Coeff::Rat(y)) => x.checked_mul(*y),
            _ => None,
        };
        fixed_width.map_or_else(|| big_op(&self, &rhs, |p, q| p.mul(&q)), Coeff::Rat)
    }
}
impl std::ops::Div for Coeff {
    type Output = Self;

    /// Exact division. If the fixed-width quotient would overflow, the result is
    /// promoted to `BigRat`.
    ///
    /// # Panics
    ///
    /// Panics with `"Coeff: division by zero"` if `rhs` is zero.
    fn div(self, rhs: Self) -> Self {
        assert!(!rhs.is_zero(), "Coeff: division by zero");
        match (&self, &rhs) {
            // a / b == a * (1/b). Using the checked multiply detects i128 overflow
            // the same way `Mul` does, rather than trusting an unchecked `Rat / Rat`.
            // NOTE(review): `b.recip()` itself is unchecked; confirm it cannot
            // overflow for an i128::MIN numerator.
            (Coeff::Rat(a), Coeff::Rat(b)) => match a.checked_mul(b.recip()) {
                Some(quotient) => Coeff::Rat(quotient),
                None => big_op(&self, &rhs, |a, b| a.div(&b)),
            },
            _ => big_op(&self, &rhs, |a, b| a.div(&b)),
        }
    }
}
impl std::ops::AddAssign for Coeff {
    /// In-place addition.
    ///
    /// `mem::take` moves the current value out, leaving the cheap `Default`
    /// (`Coeff::ZERO`) behind. This avoids cloning a potentially heap-backed
    /// `BigRat`.
    fn add_assign(&mut self, rhs: Self) {
        *self = std::mem::take(self) + rhs;
    }
}
impl std::ops::SubAssign for Coeff {
    /// In-place subtraction.
    ///
    /// `mem::take` moves the current value out, leaving the cheap `Default`
    /// (`Coeff::ZERO`) behind. This avoids cloning a potentially heap-backed
    /// `BigRat`.
    fn sub_assign(&mut self, rhs: Self) {
        *self = std::mem::take(self) - rhs;
    }
}
impl std::ops::MulAssign for Coeff {
    /// In-place multiplication.
    ///
    /// `mem::take` moves the current value out, leaving the cheap `Default`
    /// (`Coeff::ZERO`) behind. This avoids cloning a potentially heap-backed
    /// `BigRat`.
    fn mul_assign(&mut self, rhs: Self) {
        *self = std::mem::take(self) * rhs;
    }
}
impl PartialEq for Coeff {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Coeff::Rat(a), Coeff::Rat(b)) => a == b,
(Coeff::Big(a), Coeff::Big(b)) => a == b,
_ => {
let diff = self.clone() - other.clone();
diff.is_zero()
}
}
}
}
impl Eq for Coeff {}
impl PartialOrd for Coeff {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Coeff {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
match (self, other) {
(Coeff::Rat(a), Coeff::Rat(b)) => a.cmp(b),
(Coeff::Big(a), Coeff::Big(b)) => a.cmp(b),
_ => {
let diff = self.clone() - other.clone();
if diff.is_positive() {
std::cmp::Ordering::Greater
} else if diff.is_negative() {
std::cmp::Ordering::Less
} else {
std::cmp::Ordering::Equal
}
}
}
}
}
impl From<Rat> for Coeff {
fn from(r: Rat) -> Self {
Coeff::Rat(r)
}
}
impl From<BigRat> for Coeff {
fn from(b: BigRat) -> Self {
Coeff::from_bigrat(b)
}
}
impl From<i64> for Coeff {
fn from(n: i64) -> Self {
Coeff::Rat(Rat::from(n))
}
}
impl Default for Coeff {
fn default() -> Self {
Coeff::ZERO
}
}
impl std::fmt::Display for Coeff {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Coeff::Rat(r) => write!(f, "{}", r),
Coeff::Big(b) => write!(f, "{}", b),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rat_stays_rat() {
        let a = Coeff::from_i64(3) + Coeff::from_i64(4);
        assert!(matches!(a, Coeff::Rat(_)));
        assert_eq!(a.to_rat(), Some(Rat::from(7i64)));
    }

    #[test]
    fn overflow_promotes_to_big() {
        let big = Coeff::Rat(Rat::new(i128::MAX / 2 + 1, 1));
        let result = big.clone() + big;
        assert!(matches!(result, Coeff::Big(_)));
        assert!(result.is_positive());
    }

    #[test]
    fn big_demotes_when_small() {
        let a = Coeff::Big(BigRat::from_i128(42));
        let b = Coeff::Big(BigRat::from_i128(42));
        let c = a - b;
        assert!(c.is_zero());
        // The test name promises demotion, so check the representation too.
        assert!(matches!(c, Coeff::Rat(_)));
    }

    #[test]
    fn canonicalize_demotes_small_big() {
        let c = Coeff::Big(BigRat::from_i128(7)).canonicalize();
        assert!(matches!(c, Coeff::Rat(_)));
        assert_eq!(c.to_rat(), Some(Rat::from(7i64)));
    }

    #[test]
    fn mixed_arithmetic() {
        let rat = Coeff::from_i64(5);
        let big = Coeff::Big(BigRat::from_i128(i128::MAX));
        let result = rat + big;
        assert!(result.is_positive());
    }

    #[test]
    fn mixed_representations_compare_by_value() {
        let small = Coeff::from_i64(42);
        let same_but_big = Coeff::Big(BigRat::from_i128(42));
        assert_eq!(small, same_but_big);
        assert_eq!(small.cmp(&same_but_big), std::cmp::Ordering::Equal);

        let huge = Coeff::Big(BigRat::from_i128(i128::MAX));
        assert!(small < huge);
        assert!(huge > small);
    }

    #[test]
    fn ordering() {
        assert!(Coeff::from_i64(3) < Coeff::from_i64(5));
        assert!(Coeff::from_i64(-1) < Coeff::from_i64(0));
    }
}