use crate::scalar::bigrat::BigRat;
use crate::scalar::radical::RadicalElement;
use crate::scalar::rat::Rat;
/// An exact scalar stored in one of three representations.
///
/// The arithmetic operator impls choose the representation automatically:
/// `Rat` while results fit, `Big` when `Rat` arithmetic overflows, and
/// `Radical` whenever a radical operand takes part. Arithmetic results are
/// demoted back to `Rat` when they fit. A variant constructed directly
/// (e.g. `Scalar::Big` holding a small value) is not normalized until it goes
/// through arithmetic or [`Scalar::canonicalize`].
#[derive(Debug, Clone)]
pub enum Scalar {
    /// Fixed-width rational; its `checked_add`/`checked_mul` can report overflow.
    Rat(Rat),
    /// Wider (presumably arbitrary-precision) rational used when `Rat` overflows.
    Big(BigRat),
    /// A `RadicalElement` (e.g. `RadicalElement::sqrt(2)`); may also hold rationals.
    Radical(RadicalElement),
}
impl Scalar {
pub fn is_zero(&self) -> bool {
match self {
Scalar::Rat(r) => r.is_zero(),
Scalar::Big(b) => b.is_zero(),
Scalar::Radical(r) => r.is_zero(),
}
}
pub fn is_positive(&self) -> bool {
match self {
Scalar::Rat(r) => r.is_positive(),
Scalar::Big(b) => b.is_positive(),
Scalar::Radical(r) => r.is_positive(),
}
}
pub fn is_negative(&self) -> bool {
match self {
Scalar::Rat(r) => r.is_negative(),
Scalar::Big(b) => b.is_negative(),
Scalar::Radical(r) => r.is_negative(),
}
}
pub fn is_big(&self) -> bool {
matches!(self, Scalar::Big(_))
}
pub fn try_as_rat(&self) -> Option<Rat> {
match self {
Scalar::Rat(r) => Some(*r),
Scalar::Big(b) => b.to_rat(),
Scalar::Radical(r) => r.to_rat(),
}
}
pub fn canonicalize(self) -> Self {
match self {
Scalar::Big(ref b) => {
if let Some(r) = b.to_rat() {
Scalar::Rat(r)
} else {
self
}
}
Scalar::Radical(r) => {
if let Some(rat) = r.to_rat() {
Scalar::Rat(rat)
} else {
Scalar::Radical(r)
}
}
other => other,
}
}
}
impl PartialEq for Scalar {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Scalar::Rat(a), Scalar::Rat(b)) => a == b,
(Scalar::Big(a), Scalar::Big(b)) => a == b,
(Scalar::Radical(a), Scalar::Radical(b)) => a == b,
(Scalar::Rat(a), Scalar::Big(b)) => b.to_rat().as_ref() == Some(a),
(Scalar::Big(a), Scalar::Rat(b)) => a.to_rat().as_ref() == Some(b),
(Scalar::Rat(a), Scalar::Radical(b)) => b.is_rational() && b.to_rat() == Some(*a),
(Scalar::Radical(a), Scalar::Rat(b)) => a.is_rational() && a.to_rat() == Some(*b),
_ => {
let diff = self.clone() - other.clone();
diff.is_zero()
}
}
}
}
impl Eq for Scalar {}
impl PartialOrd for Scalar {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Scalar {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
use std::cmp::Ordering;
match (self, other) {
(Scalar::Rat(a), Scalar::Rat(b)) => return a.cmp(b),
(Scalar::Big(a), Scalar::Big(b)) => return a.cmp(b),
_ => {}
}
let diff = self.clone() - other.clone();
if diff.is_positive() {
Ordering::Greater
} else if diff.is_negative() {
Ordering::Less
} else {
Ordering::Equal
}
}
}
impl std::ops::Neg for Scalar {
type Output = Self;
fn neg(self) -> Self {
match self {
Scalar::Rat(r) => Scalar::Rat(-r),
Scalar::Big(b) => Scalar::Big(b.neg()),
Scalar::Radical(r) => Scalar::Radical(r.neg()),
}
}
}
/// Lifts any scalar into a `RadicalElement`.
///
/// Rationals are embedded with `from_rat`/`from_bigrat`; a radical is returned as-is.
fn scalar_to_radical(s: Scalar) -> RadicalElement {
    match s {
        Scalar::Radical(elem) => elem,
        Scalar::Big(big) => RadicalElement::from_bigrat(big),
        Scalar::Rat(small) => RadicalElement::from_rat(small),
    }
}
/// Widens a rational scalar (`Rat` or `Big`) to an owned `BigRat`.
///
/// # Panics
/// Panics if `s` is `Radical`. Callers only reach this after ruling out radical operands.
fn to_bigrat(s: &Scalar) -> BigRat {
    match s {
        Scalar::Big(big) => big.clone(),
        Scalar::Rat(small) => BigRat::from_rat(*small),
        // Explicit arm (not `_`) so adding a new variant forces this match to be revisited.
        Scalar::Radical(_) => panic!("to_bigrat called on non-rational scalar"),
    }
}
/// Adds two rational scalars in `BigRat`, demoting the sum to `Rat` when it fits.
///
/// # Panics
/// Panics (via `to_bigrat`) if either operand is `Radical`.
fn big_add(a: &Scalar, b: &Scalar) -> Scalar {
    let (lhs, rhs) = (to_bigrat(a), to_bigrat(b));
    // `From<BigRat>` performs the "fits in Rat?" demotion.
    Scalar::from(lhs.add(&rhs))
}
/// Multiplies two rational scalars in `BigRat`, demoting the product to `Rat` when it fits.
///
/// # Panics
/// Panics (via `to_bigrat`) if either operand is `Radical`.
fn big_mul(a: &Scalar, b: &Scalar) -> Scalar {
    let (lhs, rhs) = (to_bigrat(a), to_bigrat(b));
    // `From<BigRat>` performs the "fits in Rat?" demotion.
    Scalar::from(lhs.mul(&rhs))
}
/// Divides two rational scalars in `BigRat`, demoting the quotient to `Rat` when it fits.
///
/// # Panics
/// Panics (via `to_bigrat`) if either operand is `Radical`. Division by zero
/// behaves however `BigRat::div` does.
fn big_div(a: &Scalar, b: &Scalar) -> Scalar {
    let (lhs, rhs) = (to_bigrat(a), to_bigrat(b));
    // `From<BigRat>` performs the "fits in Rat?" demotion.
    Scalar::from(lhs.div(&rhs))
}
/// Exact addition that picks the narrowest representation for the sum.
///
/// - Any `Radical` operand: the sum is computed with `RadicalElement` and then canonicalized.
/// - Otherwise any `Big` operand: the sum is computed with `BigRat` (demoted if it fits).
/// - Two `Rat`s: `checked_add`, promoting to `BigRat` only on overflow.
impl std::ops::Add for Scalar {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        match (self, rhs) {
            (Scalar::Rat(a), Scalar::Rat(b)) => match a.checked_add(b) {
                Some(sum) => Scalar::Rat(sum),
                None => big_add(&Scalar::Rat(a), &Scalar::Rat(b)),
            },
            (lhs @ Scalar::Radical(_), rhs) | (lhs, rhs @ Scalar::Radical(_)) => {
                Scalar::Radical(scalar_to_radical(lhs).add(&scalar_to_radical(rhs)))
                    .canonicalize()
            }
            // No radicals and not both `Rat`: at least one operand is `Big`.
            (lhs, rhs) => big_add(&lhs, &rhs),
        }
    }
}
/// Subtraction, defined as `self + (-rhs)` so that all promotion rules live in `Add`.
impl std::ops::Sub for Scalar {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self {
        let negated_rhs = -rhs;
        self + negated_rhs
    }
}
/// Exact multiplication that picks the narrowest representation for the product.
///
/// - Any `Radical` operand: the product is computed with `RadicalElement` and then canonicalized.
/// - Otherwise any `Big` operand: the product is computed with `BigRat` (demoted if it fits).
/// - Two `Rat`s: `checked_mul`, promoting to `BigRat` only on overflow.
impl std::ops::Mul for Scalar {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        match (self, rhs) {
            (Scalar::Rat(a), Scalar::Rat(b)) => match a.checked_mul(b) {
                Some(product) => Scalar::Rat(product),
                None => big_mul(&Scalar::Rat(a), &Scalar::Rat(b)),
            },
            (lhs @ Scalar::Radical(_), rhs) | (lhs, rhs @ Scalar::Radical(_)) => {
                Scalar::Radical(scalar_to_radical(lhs).mul(&scalar_to_radical(rhs)))
                    .canonicalize()
            }
            // No radicals and not both `Rat`: at least one operand is `Big`.
            (lhs, rhs) => big_mul(&lhs, &rhs),
        }
    }
}
impl std::ops::Div for Scalar {
type Output = Self;
fn div(self, rhs: Self) -> Self {
if matches!(&self, Scalar::Radical(_)) || matches!(&rhs, Scalar::Radical(_)) {
return Scalar::Radical(scalar_to_radical(self).div(&scalar_to_radical(rhs)))
.canonicalize();
}
if self.is_big() || rhs.is_big() {
return big_div(&self, &rhs);
}
match (self, rhs) {
(Scalar::Rat(a), Scalar::Rat(b)) => Scalar::Rat(a / b),
_ => unreachable!(),
}
}
}
/// In-place `self += rhs` with the same representation rules as `Add`.
impl std::ops::AddAssign for Scalar {
    fn add_assign(&mut self, rhs: Self) {
        // Move the current value out (leaving the `Default` zero behind) instead of
        // cloning it. `Big`/`Radical` payloads may be costly to clone, and the old
        // value is discarded anyway.
        *self = std::mem::take(self) + rhs;
    }
}
/// In-place `self -= rhs` with the same representation rules as `Sub`.
impl std::ops::SubAssign for Scalar {
    fn sub_assign(&mut self, rhs: Self) {
        // Move the current value out (leaving the `Default` zero behind) instead of
        // cloning it. `Big`/`Radical` payloads may be costly to clone, and the old
        // value is discarded anyway.
        *self = std::mem::take(self) - rhs;
    }
}
/// In-place `self *= rhs` with the same representation rules as `Mul`.
impl std::ops::MulAssign for Scalar {
    fn mul_assign(&mut self, rhs: Self) {
        // Move the current value out (leaving the `Default` zero behind) instead of
        // cloning it. `Big`/`Radical` payloads may be costly to clone, and the old
        // value is discarded anyway.
        *self = std::mem::take(self) * rhs;
    }
}
impl From<Rat> for Scalar {
fn from(r: Rat) -> Self {
Scalar::Rat(r)
}
}
/// Converts an integer into a `Rat` scalar.
impl From<i64> for Scalar {
    fn from(n: i64) -> Self {
        let as_rat = Rat::from(n);
        Self::from(as_rat)
    }
}
impl From<BigRat> for Scalar {
fn from(b: BigRat) -> Self {
if let Some(r) = b.to_rat() {
Scalar::Rat(r)
} else {
Scalar::Big(b)
}
}
}
impl From<RadicalElement> for Scalar {
fn from(r: RadicalElement) -> Self {
if let Some(rat) = r.to_rat() {
Scalar::Rat(rat)
} else {
Scalar::Radical(r)
}
}
}
impl std::fmt::Display for Scalar {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Scalar::Rat(r) => write!(f, "{}", r),
Scalar::Big(b) => write!(f, "{}", b),
Scalar::Radical(r) => write!(f, "{}", r),
}
}
}
/// The default scalar is `Rat::ZERO`.
impl Default for Scalar {
    fn default() -> Self {
        Self::from(Rat::ZERO)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Scalar`: representation choice (promotion to `Big` and
    //! `Radical`, demotion back to `Rat`), sign/zero predicates, equality,
    //! ordering and display.
    use super::*;
    use std::cmp::Ordering;

    #[test]
    fn rat_stays_rat() {
        let a = Scalar::from(1i64) + Scalar::from(2i64);
        assert!(matches!(a, Scalar::Rat(_)));
        assert_eq!(a.try_as_rat(), Some(Rat::from(3i64)));
    }

    #[test]
    fn rat_promotes_to_radical() {
        let a = Scalar::Rat(Rat::from(1i64));
        let b = Scalar::Radical(RadicalElement::sqrt(Rat::from(2)));
        let c = a + b;
        assert!(matches!(c, Scalar::Radical(_)));
    }

    #[test]
    fn radical_mul_demotes() {
        let a = Scalar::Radical(RadicalElement::sqrt(Rat::from(2)));
        let c = a.clone() * a;
        // Mul canonicalizes, so a rational product must come back as `Rat`.
        assert!(matches!(c, Scalar::Rat(_)), "expected Rat, got {:?}", c);
        assert_eq!(c.try_as_rat(), Some(Rat::from(2i64)));
    }

    #[test]
    fn zero_detection() {
        assert!(Scalar::from(0i64).is_zero());
        assert!(!Scalar::from(1i64).is_zero());
        assert!(!Scalar::from(-1i64).is_zero());
    }

    #[test]
    fn positive_negative() {
        assert!(Scalar::from(5i64).is_positive());
        assert!(!Scalar::from(5i64).is_negative());
        assert!(Scalar::from(-3i64).is_negative());
        assert!(!Scalar::from(-3i64).is_positive());
        assert!(!Scalar::from(0i64).is_positive());
        assert!(!Scalar::from(0i64).is_negative());
    }

    #[test]
    fn from_i64_roundtrip() {
        let s = Scalar::from(42i64);
        assert_eq!(s.try_as_rat(), Some(Rat::from(42i64)));
    }

    #[test]
    fn equality_across_types() {
        let a = Scalar::from(3i64);
        let b = Scalar::Radical(RadicalElement::from_rat(Rat::from(3)));
        assert_eq!(a, b);
    }

    #[test]
    fn display() {
        assert_eq!(format!("{}", Scalar::from(7i64)), "7");
    }

    #[test]
    fn overflow_add_produces_big() {
        let big = Scalar::Rat(Rat::new(i128::MAX / 2 + 1, 1));
        let result = big.clone() + big;
        assert!(result.is_big(), "expected Big, got {:?}", result);
        assert!(!result.is_zero());
        assert!(result.is_positive());
    }

    #[test]
    fn overflow_mul_produces_big() {
        let big = Scalar::Rat(Rat::new(i128::MAX / 2, 1));
        let result = big * Scalar::from(3i64);
        assert!(result.is_big(), "expected Big, got {:?}", result);
    }

    #[test]
    fn big_demotes_when_small() {
        // Force a genuine overflow into `Big` first (the previous version stayed
        // within `Rat` the whole time), then subtract back into `Rat` range.
        let half = Scalar::Rat(Rat::new(i128::MAX / 2 + 1, 1));
        let big = half.clone() + half.clone();
        assert!(big.is_big(), "expected Big, got {:?}", big);
        let result = big - half.clone();
        assert!(matches!(result, Scalar::Rat(_)), "expected Rat, got {:?}", result);
        assert_eq!(result, half);
    }

    #[test]
    fn big_add_big() {
        let a = Scalar::Big(BigRat::from_i128_pair(i128::MAX, 1));
        let b = Scalar::Big(BigRat::from_i128_pair(i128::MAX, 1));
        let c = a + b;
        assert!(c.is_big());
        assert!(c.is_positive());
    }

    #[test]
    fn big_mul_rat() {
        let a = Scalar::Big(BigRat::from_i128_pair(i128::MAX, 1));
        let b = Scalar::from(2i64);
        let c = a * b;
        assert!(c.is_big());
    }

    #[test]
    fn big_cancellation_demotes() {
        let a = Scalar::Big(BigRat::from_i128(42));
        let b = Scalar::Big(BigRat::from_i128(42));
        let c = a - b;
        assert!(c.is_zero());
        assert!(matches!(c, Scalar::Rat(_)));
    }

    #[test]
    fn normal_add_stays_rat() {
        let a = Scalar::from(3i64) + Scalar::from(4i64);
        assert!(matches!(a, Scalar::Rat(_)));
        assert_eq!(a.try_as_rat(), Some(Rat::from(7i64)));
    }

    #[test]
    fn neg_big() {
        let a = Scalar::Big(BigRat::from_i128(5));
        let neg = -a;
        assert!(neg.is_negative());
    }

    #[test]
    fn scalar_ord_rat_rat() {
        assert!(Scalar::from(3i64) < Scalar::from(5i64));
        assert!(Scalar::from(-1i64) < Scalar::from(0i64));
        assert!(Scalar::from(7i64) > Scalar::from(7i64 - 1));
    }

    #[test]
    fn scalar_ord_rat_radical() {
        let sqrt2 = Scalar::Radical(RadicalElement::sqrt(Rat::from(2)));
        assert!(Scalar::from(1i64) < sqrt2);
        assert!(sqrt2 < Scalar::from(2i64));
    }

    #[test]
    fn scalar_ord_transitivity() {
        let a = Scalar::from(-5i64);
        let b = Scalar::from(0i64);
        let c = Scalar::from(5i64);
        assert!(a < b);
        assert!(b < c);
        assert!(a < c);
    }

    #[test]
    fn scalar_ord_equality() {
        let a = Scalar::from(7i64);
        let b = Scalar::from(7i64);
        assert!(a == b);
        assert!(!(a < b));
        assert!(!(a > b));
    }

    #[test]
    fn scalar_ord_big_vs_rat() {
        let big = Scalar::Big(BigRat::from_i128(42));
        let rat = Scalar::from(42i64);
        assert!(big == rat);
        // The test name promises ordering, so also exercise `cmp` across representations.
        assert_eq!(big.cmp(&rat), Ordering::Equal);
        assert!(Scalar::Big(BigRat::from_i128(41)) < rat);
        assert!(Scalar::Big(BigRat::from_i128(43)) > rat);
    }
}