use crate::Problem;
use crate::structural::{RationalFacts, RationalStorageClass};
use num::bigint::Sign::{self, *};
use num::{BigInt, BigUint, ToPrimitive};
use num::{One, Zero};
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize};
use std::cmp::Ordering;
use std::sync::LazyLock;
/// An exact rational number stored as a sign plus unsigned numerator/denominator magnitudes.
///
/// Intended invariants, maintained by the constructors in this file rather than by the type:
/// * `sign == NoSign` exactly when the value is zero (zero is built as `0/1`);
/// * `denominator` is never zero;
/// * values are normally kept in lowest terms (most constructors finish with `reduce`).
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct Rational {
// Sign of the value; `NoSign` is reserved for zero.
sign: Sign,
// Magnitude of the numerator.
numerator: BigUint,
// Magnitude of the denominator; constructors reject zero.
denominator: BigUint,
}
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Rational {
    /// Reads the raw `sign` / `numerator` / `denominator` fields, rejects a zero
    /// denominator, and normalizes the value to lowest terms.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Mirrors the serialized layout so the fields can be validated before a
        // `Rational` is assembled from them.
        #[derive(Deserialize)]
        struct RationalWire {
            sign: Sign,
            numerator: BigUint,
            denominator: BigUint,
        }
        let RationalWire {
            sign,
            numerator,
            denominator,
        } = RationalWire::deserialize(deserializer)?;
        if denominator.is_zero() {
            Err(serde::de::Error::custom(
                "Rational denominator must be nonzero",
            ))
        } else {
            Ok(Self::from_fraction_parts(sign, numerator, denominator).reduce())
        }
    }
}
// Shared small `BigUint` constants, initialized on first use.
static ONE: LazyLock<BigUint> = LazyLock::new(|| BigUint::from(1_u8));
static TWO: LazyLock<BigUint> = LazyLock::new(|| BigUint::from(2_u8));
static FIVE: LazyLock<BigUint> = LazyLock::new(|| BigUint::from(5_u8));
static TEN: LazyLock<BigUint> = LazyLock::new(|| BigUint::from(10_u8));
// Each `trace_rational_*` macro forwards to `crate::dispatch_trace` only when the
// `dispatch-trace` feature is enabled. With the feature off the whole statement is
// compiled out, so the macro arguments are not evaluated at all.
//
// Invoked immediately before a new `Rational` value is constructed.
macro_rules! trace_rational_temporary {
() => {{
#[cfg(feature = "dispatch-trace")]
crate::dispatch_trace::record_rational_temporary();
}};
}
// Invoked with the numerator and denominator at the start of a reduction attempt.
macro_rules! trace_rational_reduction {
($numerator:expr, $denominator:expr) => {{
#[cfg(feature = "dispatch-trace")]
crate::dispatch_trace::record_rational_reduction($numerator, $denominator);
}};
}
// Invoked with the two gcd operands and the divisor that was computed from them.
macro_rules! trace_rational_gcd {
($left:expr, $right:expr, $divisor:expr) => {{
#[cfg(feature = "dispatch-trace")]
crate::dispatch_trace::record_rational_gcd($left, $right, $divisor);
}};
}
// Invoked with the power-of-two shift (possibly 0) removed by `reduce_by_power_of_two_divisor`.
macro_rules! trace_rational_power_of_two_common_factor {
($shift:expr) => {{
#[cfg(feature = "dispatch-trace")]
crate::dispatch_trace::record_rational_power_of_two_common_factor($shift);
}};
}
impl Rational {
/// Returns the canonical zero: sign `NoSign`, stored as `0/1`.
pub fn zero() -> Self {
    trace_rational_temporary!();
    let (numerator, denominator) = (BigUint::ZERO, BigUint::one());
    Self {
        sign: NoSign,
        numerator,
        denominator,
    }
}
/// Returns the value one, stored as `+1/1`.
pub fn one() -> Self {
    trace_rational_temporary!();
    Self {
        denominator: BigUint::one(),
        numerator: BigUint::one(),
        sign: Plus,
    }
}
/// Builds the integer `n` as a rational with denominator one.
pub fn new(n: i64) -> Self {
    // Zero passes through as `Plus`; `from_integer_magnitude` normalizes it to zero.
    let sign = if n.is_negative() { Minus } else { Plus };
    let magnitude = BigUint::from(n.unsigned_abs());
    Self::from_integer_magnitude(sign, magnitude)
}
/// Converts an arbitrary-precision integer into a rational with denominator one.
pub fn from_bigint(n: BigInt) -> Self {
    // A denominator of one can never trigger the divide-by-zero error.
    let denominator = BigUint::one();
    let converted = Self::from_bigint_fraction(n, denominator);
    converted.unwrap()
}
/// Builds `n / d` in lowest terms.
///
/// # Errors
/// Returns `Problem::DivideByZero` when `d` is zero.
pub fn fraction(n: i64, d: u64) -> Result<Self, Problem> {
    if d == 0 {
        return Err(Problem::DivideByZero);
    }
    let sign = if n.is_negative() { Minus } else { Plus };
    let unreduced = Self::from_fraction_parts(
        sign,
        BigUint::from(n.unsigned_abs()),
        BigUint::from(d),
    );
    Ok(unreduced.reduce())
}
/// Builds `n / denominator` in lowest terms from a signed numerator.
///
/// # Errors
/// Returns `Problem::DivideByZero` when `denominator` is zero.
pub fn from_bigint_fraction(n: BigInt, denominator: BigUint) -> Result<Self, Problem> {
    if denominator.is_zero() {
        return Err(Problem::DivideByZero);
    }
    let (sign, numerator) = n.into_parts();
    Ok(Self::from_fraction_parts(sign, numerator, denominator).reduce())
}
/// Builds `sign * numerator` with denominator one; a zero magnitude yields canonical zero.
pub(crate) fn from_integer_magnitude(sign: Sign, numerator: BigUint) -> Self {
    let denominator = BigUint::one();
    Self::from_fraction_parts(sign, numerator, denominator)
}
/// Builds the non-negative integer `numerator` as a rational.
pub(crate) fn from_unsigned_integer(numerator: BigUint) -> Self {
    let non_negative = Plus;
    Self::from_integer_magnitude(non_negative, numerator)
}
/// Assembles a rational from raw parts without reducing it.
///
/// A `NoSign` sign or a zero numerator collapses to the canonical zero. The denominator
/// is not checked here; callers are responsible for keeping it nonzero.
fn from_fraction_parts(sign: Sign, numerator: BigUint, denominator: BigUint) -> Self {
    let represents_zero = sign == NoSign || numerator.is_zero();
    if represents_zero {
        Self::zero()
    } else {
        trace_rational_temporary!();
        Self {
            sign,
            numerator,
            denominator,
        }
    }
}
/// Returns `self + 1` by adjusting the numerator against the existing denominator.
pub(crate) fn add_one(&self) -> Self {
    let den = &self.denominator;
    let num = &self.numerator;
    match (self.sign, num.cmp(den)) {
        (NoSign, _) => Self::one(),
        (Plus, _) => Self::from_fraction_parts(Plus, num + den, den.clone()),
        // Negative with |self| > 1: the result stays negative.
        (Minus, Ordering::Greater) => Self::from_fraction_parts(Minus, num - den, den.clone()),
        (Minus, Ordering::Equal) => Self::zero(),
        // Negative with |self| < 1: the result crosses zero.
        (Minus, Ordering::Less) => Self::from_fraction_parts(Plus, den - num, den.clone()),
    }
}
/// Returns `self - 1` by adjusting the numerator against the existing denominator.
pub(crate) fn subtract_one(&self) -> Self {
    let den = &self.denominator;
    let num = &self.numerator;
    match (self.sign, num.cmp(den)) {
        (NoSign, _) => Self::from_integer_magnitude(Minus, ONE.deref().clone()),
        (Minus, _) => Self::from_fraction_parts(Minus, num + den, den.clone()),
        // Positive with |self| > 1: the result stays positive.
        (Plus, Ordering::Greater) => Self::from_fraction_parts(Plus, num - den, den.clone()),
        (Plus, Ordering::Equal) => Self::zero(),
        // Positive with |self| < 1: the result crosses zero.
        (Plus, Ordering::Less) => Self::from_fraction_parts(Minus, den - num, den.clone()),
    }
}
/// Reduces `self` to lowest terms. A power-of-two denominator takes the shift-based path;
/// any other denominator falls back to the gcd-based `reduce`.
fn maybe_reduce(self) -> Self {
if Self::is_power_of_two(&self.denominator) {
// Clone so the denominator can be borrowed while `self` is moved into the call.
let denominator = self.denominator.clone();
trace_rational_reduction!(&self.numerator, &self.denominator);
self.reduce_by_power_of_two_divisor(&denominator)
} else {
self.reduce()
}
}
/// Cancels common factors, assuming every factor shared by numerator and denominator
/// also divides `possible_divisor`. Only `gcd(numerator, possible_divisor)` is removed.
///
/// `Add` relies on this, passing the gcd of the two operand denominators.
fn reduce_with_possible_divisor(self, possible_divisor: &BigUint) -> Self {
if self.sign == NoSign || self.numerator.is_zero() {
return Self::zero();
}
// Nothing to cancel: the denominator is already 1, or the hint rules out any common factor.
if self.denominator == *ONE.deref() || possible_divisor == &*ONE {
return self;
}
trace_rational_reduction!(&self.numerator, &self.denominator);
// A power-of-two hint can be handled with trailing-zero counts instead of a gcd.
if Self::is_power_of_two(possible_divisor) {
return self.reduce_by_power_of_two_divisor(possible_divisor);
}
let divisor = num::Integer::gcd(&self.numerator, possible_divisor);
trace_rational_gcd!(&self.numerator, possible_divisor, &divisor);
if divisor == *ONE.deref() {
self
} else {
trace_rational_temporary!();
Self {
sign: self.sign,
numerator: self.numerator / &divisor,
denominator: self.denominator / divisor,
}
}
}
/// Reduces `self` to lowest terms by dividing out `gcd(numerator, denominator)`.
///
/// A power-of-two denominator uses a trailing-zero shift instead of a gcd.
fn reduce(self) -> Self {
// Denominator 1 is already in lowest terms.
if self.denominator == *ONE.deref() {
return self;
}
trace_rational_reduction!(&self.numerator, &self.denominator);
if Self::is_power_of_two(&self.denominator) {
// Clone so the denominator can be borrowed while `self` is moved into the call.
let denominator = self.denominator.clone();
return self.reduce_by_power_of_two_divisor(&denominator);
}
let divisor = num::Integer::gcd(&self.numerator, &self.denominator);
trace_rational_gcd!(&self.numerator, &self.denominator, &divisor);
if divisor == *ONE.deref() {
self
} else {
let numerator = self.numerator / &divisor;
let denominator = self.denominator / &divisor;
trace_rational_temporary!();
Self {
sign: self.sign,
numerator,
denominator,
}
}
}
/// Returns `Some(k)` when `value == 2^k`; returns `None` for zero and for any value that
/// is not a power of two.
fn biguint_power_of_two_shift(value: &BigUint) -> Option<u64> {
    // `trailing_zeros` is `None` only for zero, which is not a power of two.
    let shift = value.trailing_zeros()?;
    // A power of two has exactly one set bit: the highest bit is also the lowest.
    if shift + 1 == value.bits() {
        Some(shift)
    } else {
        None
    }
}
/// Returns whether `value` is `2^k` for some `k >= 0`.
fn is_power_of_two(value: &BigUint) -> bool {
    matches!(Self::biguint_power_of_two_shift(value), Some(_))
}
/// Removes the largest power of two shared by the numerator and `possible_divisor`,
/// shifting both numerator and denominator right by that amount.
///
/// Only the trailing-zero count of `possible_divisor` is used, so it is expected to be a
/// power of two. Callers here pass the denominator, or the gcd of two denominators,
/// both of which divide `self.denominator`. The call panics if `possible_divisor` is zero.
fn reduce_by_power_of_two_divisor(self, possible_divisor: &BigUint) -> Self {
if self.sign == NoSign || self.numerator.is_zero() {
return Self::zero();
}
let numerator_shift = self
.numerator
.trailing_zeros()
.expect("non-zero numerator has trailing zeros");
// An odd numerator shares no factor of two with anything.
if numerator_shift == 0 {
trace_rational_power_of_two_common_factor!(0);
return self;
}
let divisor_shift = possible_divisor
.trailing_zeros()
.expect("power-of-two divisor has trailing zeros");
let shift = numerator_shift.min(divisor_shift);
if shift == 0 {
trace_rational_power_of_two_common_factor!(0);
return self;
}
let shift = usize::try_from(shift).expect("shift should fit in usize");
trace_rational_power_of_two_common_factor!(shift as u64);
trace_rational_temporary!();
Self {
sign: self.sign,
numerator: self.numerator >> shift,
denominator: self.denominator >> shift,
}
}
/// Returns the reciprocal `1 / self`, keeping the sign.
///
/// # Errors
/// Returns `Problem::DivideByZero` when `self` is zero.
pub fn inverse(self) -> Result<Self, Problem> {
    let Self {
        sign,
        numerator,
        denominator,
    } = self;
    if numerator.is_zero() {
        return Err(Problem::DivideByZero);
    }
    // Swapping coprime parts keeps the value in lowest terms.
    Ok(Self {
        sign,
        numerator: denominator,
        denominator: numerator,
    })
}
pub fn is_integer(&self) -> bool {
self.denominator == *ONE.deref()
}
/// Returns whether the denominator is a power of two.
pub fn is_dyadic(&self) -> bool {
    Self::biguint_power_of_two_shift(&self.denominator).is_some()
}
/// Returns whether both values store exactly the same denominator.
#[inline]
pub fn same_denominator(&self, other: &Self) -> bool {
    other.denominator == self.denominator
}
/// Returns `Some(k)` when the denominator is `2^k`, otherwise `None`.
pub(crate) fn dyadic_denominator_shift(&self) -> Option<u64> {
    let denominator = &self.denominator;
    Self::biguint_power_of_two_shift(denominator)
}
/// Builds `(positive - negative) / denominator` in lowest terms from two non-negative
/// accumulated magnitudes.
fn from_signed_magnitude_difference(
    positive: BigUint,
    negative: BigUint,
    denominator: BigUint,
) -> Self {
    let (sign, numerator) = if positive > negative {
        (Plus, positive - negative)
    } else if positive < negative {
        (Minus, negative - positive)
    } else {
        // The two sides cancel exactly.
        return Self::zero();
    };
    trace_rational_temporary!();
    let unreduced = Self {
        sign,
        numerator,
        denominator,
    };
    unreduced.maybe_reduce()
}
/// Fast path for `sum_i signs[i] * |left[i]| * |right[i]|` when every live term has
/// power-of-two denominators.
///
/// Each term's combined denominator is `2^(shift_left + shift_right)`. Terms are rescaled
/// to the largest such power by left-shifting their numerator product, so no gcd or lcm
/// work is needed. Terms whose sign is `NoSign` are skipped. `signs[i]` is taken as the
/// term's sign; only numerator magnitudes are multiplied.
///
/// Returns `None` as soon as a live factor has a non-power-of-two denominator.
fn dot_products_dyadic<const N: usize>(
left: [&Self; N],
right: [&Self; N],
signs: [Sign; N],
) -> Option<Self> {
let mut max_shift = 0_u64;
let mut denominator_shifts = [0_u64; N];
let mut any_nonzero = false;
// Pass 1: collect each live term's denominator exponent (bail out if not dyadic).
for i in 0..N {
if signs[i] == NoSign {
continue;
}
let shift =
left[i].dyadic_denominator_shift()? + right[i].dyadic_denominator_shift()?;
denominator_shifts[i] = shift;
max_shift = max_shift.max(shift);
any_nonzero = true;
}
if !any_nonzero {
return Some(Self::zero());
}
// Pass 2: accumulate positive and negative contributions over the shared 2^max_shift.
let mut positive = BigUint::ZERO;
let mut negative = BigUint::ZERO;
for i in 0..N {
let sign = signs[i];
if sign == NoSign {
continue;
}
let scale_shift = usize::try_from(max_shift - denominator_shifts[i])
.expect("dyadic dot-product scale should fit in usize");
let mut magnitude = &left[i].numerator * &right[i].numerator;
if scale_shift != 0 {
magnitude <<= scale_shift;
}
match sign {
Plus => positive += magnitude,
Minus => negative += magnitude,
NoSign => {}
}
}
let denominator =
BigUint::one() << usize::try_from(max_shift).expect("shift should fit in usize");
Some(Self::from_signed_magnitude_difference(
positive,
negative,
denominator,
))
}
/// Fast path for `sum_i signs[i] * |left[i]| * |right[i]|` when every live term has the
/// same product denominator `left[i].denominator * right[i].denominator`.
///
/// In that case the numerator products can be summed directly. The call returns `None`
/// as soon as two live terms have different product denominators. Terms whose sign is
/// `NoSign` are skipped.
fn dot_products_equal_denominator<const N: usize>(
left: [&Self; N],
right: [&Self; N],
signs: [Sign; N],
) -> Option<Self> {
let mut shared_denominator = None::<BigUint>;
for i in 0..N {
if signs[i] == NoSign {
continue;
}
let denominator = &left[i].denominator * &right[i].denominator;
match &shared_denominator {
None => shared_denominator = Some(denominator),
Some(shared) if *shared == denominator => {}
Some(_) => return None,
}
}
// No live term at all: the sum is zero.
let Some(denominator) = shared_denominator else {
return Some(Self::zero());
};
let mut positive = BigUint::ZERO;
let mut negative = BigUint::ZERO;
for i in 0..N {
let sign = signs[i];
if sign == NoSign {
continue;
}
let magnitude = &left[i].numerator * &right[i].numerator;
match sign {
Plus => positive += magnitude,
Minus => negative += magnitude,
NoSign => {}
}
}
Some(Self::from_signed_magnitude_difference(
positive,
negative,
denominator,
))
}
/// Multiplies together the denominators of every factor in `term` (1 for an empty term).
fn product_term_denominator<const FACTORS: usize>(term: [&Self; FACTORS]) -> BigUint {
    term.iter()
        .fold(BigUint::one(), |product, factor| product * &factor.denominator)
}
/// Multiplies together the numerator magnitudes of every factor in `term` (1 for an empty term).
fn product_term_magnitude<const FACTORS: usize>(term: [&Self; FACTORS]) -> BigUint {
    term.iter()
        .fold(BigUint::one(), |product, factor| product * &factor.numerator)
}
/// Sign of `(+/-1) * product(term)`: `positive` picks the leading sign, and any zero factor
/// makes the result `NoSign`.
fn product_term_sign<const FACTORS: usize>(positive: bool, term: [&Self; FACTORS]) -> Sign {
    let leading = if positive { Plus } else { Minus };
    term.iter().fold(leading, |sign, factor| sign * factor.sign)
}
/// Fast path for `sum_i signs[i] * product_j |terms[i][j]|` when every factor of every live
/// term has a power-of-two denominator.
///
/// Term products are rescaled to the largest combined `2^max_shift` by left shifts, so no
/// gcd or lcm work is needed. Terms whose sign is `NoSign` are skipped. Returns `None` as
/// soon as a live factor is not dyadic.
fn signed_product_sum_dyadic<const TERMS: usize, const FACTORS: usize>(
terms: [[&Self; FACTORS]; TERMS],
signs: [Sign; TERMS],
) -> Option<Self> {
let mut max_shift = 0_u64;
let mut denominator_shifts = [0_u64; TERMS];
let mut any_nonzero = false;
// Pass 1: each live term's denominator exponent is the sum of its factors' exponents.
for i in 0..TERMS {
if signs[i] == NoSign {
continue;
}
let mut shift = 0_u64;
for factor in terms[i] {
shift += factor.dyadic_denominator_shift()?;
}
denominator_shifts[i] = shift;
max_shift = max_shift.max(shift);
any_nonzero = true;
}
if !any_nonzero {
return Some(Self::zero());
}
// Pass 2: accumulate rescaled magnitudes over the shared 2^max_shift denominator.
let mut positive = BigUint::ZERO;
let mut negative = BigUint::ZERO;
for i in 0..TERMS {
let sign = signs[i];
if sign == NoSign {
continue;
}
let scale_shift = usize::try_from(max_shift - denominator_shifts[i])
.expect("dyadic product-sum scale should fit in usize");
let mut magnitude = Self::product_term_magnitude(terms[i]);
if scale_shift != 0 {
magnitude <<= scale_shift;
}
match sign {
Plus => positive += magnitude,
Minus => negative += magnitude,
NoSign => {}
}
}
let denominator =
BigUint::one() << usize::try_from(max_shift).expect("shift should fit in usize");
Some(Self::from_signed_magnitude_difference(
positive,
negative,
denominator,
))
}
/// Computes `sum_i (+/-) product_j terms[i][j]` when every factor of every live term has
/// the same denominator `d`. The result is then built over `d^FACTORS` without any gcd
/// or lcm work.
///
/// `positive_terms[i]` chooses the leading sign of term `i`; the factors' own signs are
/// folded in, and a term with a zero factor is skipped. The call returns `None` when two
/// live factors have different denominators, or when `FACTORS` does not fit in `u32`.
pub fn signed_product_sum_shared_denominator<const TERMS: usize, const FACTORS: usize>(
positive_terms: [bool; TERMS],
terms: [[&Self; FACTORS]; TERMS],
) -> Option<Self> {
debug_assert!(FACTORS > 0);
let mut signs = [NoSign; TERMS];
let mut nonzero_count = 0_usize;
let mut shared_denominator = None::<&BigUint>;
for i in 0..TERMS {
let sign = Self::product_term_sign(positive_terms[i], terms[i]);
if sign == NoSign {
signs[i] = sign;
continue;
}
nonzero_count += 1;
signs[i] = sign;
// Every factor of a live term must share the one denominator seen so far.
for factor in terms[i] {
match shared_denominator {
None => shared_denominator = Some(&factor.denominator),
Some(shared) if shared == &factor.denominator => {}
Some(_) => return None,
}
}
}
if nonzero_count == 0 {
crate::trace_dispatch!(
"rational",
"product_sum",
"shared-factor-denominator-all-zero"
);
return Some(Self::zero());
}
// Each live term is a product of FACTORS fractions over d, so its denominator is d^FACTORS.
let exponent = u32::try_from(FACTORS).ok()?;
let denominator = shared_denominator
.expect("nonzero product sum has a live factor denominator")
.pow(exponent);
let mut positive = BigUint::ZERO;
let mut negative = BigUint::ZERO;
for i in 0..TERMS {
let sign = signs[i];
if sign == NoSign {
continue;
}
let magnitude = Self::product_term_magnitude(terms[i]);
match sign {
Plus => positive += magnitude,
Minus => negative += magnitude,
NoSign => {}
}
}
crate::trace_dispatch!("rational", "product_sum", "shared-factor-denominator");
Some(Self::from_signed_magnitude_difference(
positive,
negative,
denominator,
))
}
/// Computes `sum_i (+/-) product_j terms[i][j]` exactly, in lowest terms.
///
/// `positive_terms[i]` chooses the leading sign of term `i`; factor signs are folded in,
/// and a term with a zero factor contributes nothing. Strategies are tried in order:
/// 1. all terms are zero;
/// 2. a single live term;
/// 3. all factors dyadic (shift-based);
/// 4. all live term denominators equal;
/// 5. general lcm of the term denominators.
pub fn signed_product_sum<const TERMS: usize, const FACTORS: usize>(
positive_terms: [bool; TERMS],
terms: [[&Self; FACTORS]; TERMS],
) -> Self {
debug_assert!(FACTORS > 0);
let mut signs = [NoSign; TERMS];
let mut nonzero_count = 0_usize;
for i in 0..TERMS {
let sign = Self::product_term_sign(positive_terms[i], terms[i]);
if sign != NoSign {
nonzero_count += 1;
}
signs[i] = sign;
}
if nonzero_count == 0 {
crate::trace_dispatch!("rational", "product_sum", "all-zero");
return Self::zero();
}
// Exactly one live term: the answer is that term's product.
if nonzero_count == 1 {
for i in 0..TERMS {
match signs[i] {
Plus => {
crate::trace_dispatch!("rational", "product_sum", "single-term-product");
let denominator = Self::product_term_denominator(terms[i]);
return Self::from_signed_magnitude_difference(
Self::product_term_magnitude(terms[i]),
BigUint::ZERO,
denominator,
);
}
Minus => {
crate::trace_dispatch!("rational", "product_sum", "single-term-product");
let denominator = Self::product_term_denominator(terms[i]);
return Self::from_signed_magnitude_difference(
BigUint::ZERO,
Self::product_term_magnitude(terms[i]),
denominator,
);
}
NoSign => {}
}
}
}
if let Some(dyadic) = Self::signed_product_sum_dyadic(terms, signs) {
crate::trace_dispatch!("rational", "product_sum", "dyadic-shared-denominator");
return dyadic;
}
// Compute each live term's denominator once; it is reused by the lcm fallback below.
let mut denominators: [BigUint; TERMS] = std::array::from_fn(|_| BigUint::ZERO);
let mut shared_denominator = None::<BigUint>;
let mut equal_denominator = true;
for i in 0..TERMS {
if signs[i] == NoSign {
continue;
}
let denominator = Self::product_term_denominator(terms[i]);
match &shared_denominator {
None => shared_denominator = Some(denominator.clone()),
Some(shared) if *shared == denominator => {}
Some(_) => equal_denominator = false,
}
denominators[i] = denominator;
}
if equal_denominator {
let denominator = shared_denominator.expect("nonzero product sum has denominator");
let mut positive = BigUint::ZERO;
let mut negative = BigUint::ZERO;
for i in 0..TERMS {
let sign = signs[i];
if sign == NoSign {
continue;
}
let magnitude = Self::product_term_magnitude(terms[i]);
match sign {
Plus => positive += magnitude,
Minus => negative += magnitude,
NoSign => {}
}
}
crate::trace_dispatch!("rational", "product_sum", "equal-product-denominator");
return Self::from_signed_magnitude_difference(positive, negative, denominator);
}
crate::trace_dispatch!("rational", "product_sum", "lcm-shared-denominator");
// lcm(a, b) = a * (b / gcd(a, b)), folded over the live term denominators.
let mut common_denominator = BigUint::one();
for i in 0..TERMS {
if signs[i] == NoSign {
continue;
}
let denominator = &denominators[i];
if denominator != ONE.deref() {
let divisor = num::Integer::gcd(&common_denominator, denominator);
trace_rational_gcd!(&common_denominator, denominator, &divisor);
common_denominator *= denominator / &divisor;
}
}
let mut positive = BigUint::ZERO;
let mut negative = BigUint::ZERO;
for i in 0..TERMS {
let sign = signs[i];
if sign == NoSign {
continue;
}
// Rescale the term's numerator product onto the common denominator.
let mut magnitude = Self::product_term_magnitude(terms[i]);
let denominator = &denominators[i];
if denominator != &common_denominator {
magnitude *= &common_denominator / denominator;
}
match sign {
Plus => positive += magnitude,
Minus => negative += magnitude,
NoSign => {}
}
}
Self::from_signed_magnitude_difference(positive, negative, common_denominator)
}
/// Computes the exact dot product `sum_i left[i] * right[i]` in lowest terms.
///
/// Strategies are tried in order:
/// 1. all products are zero;
/// 2. a single nonzero product;
/// 3. all dyadic (shift-based);
/// 4. equal product denominators;
/// 5. general lcm of the product denominators.
pub(crate) fn dot_products<const N: usize>(left: [&Self; N], right: [&Self; N]) -> Self {
let mut signs = [NoSign; N];
let mut nonzero_count = 0_usize;
for i in 0..N {
// `NoSign` on either side makes the product's sign `NoSign`.
let sign = left[i].sign * right[i].sign;
if sign != NoSign {
nonzero_count += 1;
}
signs[i] = sign;
}
if nonzero_count == 0 {
crate::trace_dispatch!("rational", "dot_product", "all-zero");
return Self::zero();
}
if nonzero_count == 1 {
let mut positive = BigUint::ZERO;
let mut negative = BigUint::ZERO;
for i in 0..N {
match signs[i] {
Plus => {
let denominator = &left[i].denominator * &right[i].denominator;
positive = &left[i].numerator * &right[i].numerator;
crate::trace_dispatch!("rational", "dot_product", "single-term-product");
return Self::from_signed_magnitude_difference(
positive,
negative,
denominator,
);
}
Minus => {
let denominator = &left[i].denominator * &right[i].denominator;
negative = &left[i].numerator * &right[i].numerator;
crate::trace_dispatch!("rational", "dot_product", "single-term-product");
return Self::from_signed_magnitude_difference(
positive,
negative,
denominator,
);
}
NoSign => {}
}
}
return Self::zero();
}
if let Some(dyadic) = Self::dot_products_dyadic(left, right, signs) {
crate::trace_dispatch!("rational", "dot_product", "dyadic-shared-denominator");
return dyadic;
}
if let Some(equal_denominator) = Self::dot_products_equal_denominator(left, right, signs) {
crate::trace_dispatch!("rational", "dot_product", "equal-product-denominator");
return equal_denominator;
}
crate::trace_dispatch!("rational", "dot_product", "lcm-shared-denominator");
// lcm(a, b) = a * (b / gcd(a, b)), folded over the live product denominators.
let mut common_denominator = BigUint::one();
let mut any_nonzero = false;
for i in 0..N {
if signs[i] == NoSign {
continue;
}
let denominator = &left[i].denominator * &right[i].denominator;
if denominator != *ONE.deref() {
let divisor = num::Integer::gcd(&common_denominator, &denominator);
trace_rational_gcd!(&common_denominator, &denominator, &divisor);
common_denominator *= denominator / &divisor;
}
any_nonzero = true;
}
if !any_nonzero {
return Self::zero();
}
let mut positive = BigUint::ZERO;
let mut negative = BigUint::ZERO;
for i in 0..N {
let sign = signs[i];
if sign == NoSign {
continue;
}
// Rescale the numerator product onto the common denominator.
let denominator = &left[i].denominator * &right[i].denominator;
let mut magnitude = &left[i].numerator * &right[i].numerator;
if denominator != common_denominator {
magnitude *= &common_denominator / denominator;
}
match sign {
Plus => positive += magnitude,
Minus => negative += magnitude,
NoSign => {}
}
}
Self::from_signed_magnitude_difference(positive, negative, common_denominator)
}
#[inline(always)]
pub fn is_zero(&self) -> bool {
self.sign == NoSign
}
#[inline(always)]
pub fn is_one(&self) -> bool {
self.sign == Plus && self.numerator == *ONE.deref() && self.denominator == *ONE.deref()
}
/// Returns whether the value is exactly `+2/1`. The bit-length test is a cheap early reject.
#[inline]
pub(crate) fn is_two(&self) -> bool {
    let positive_integer = self.sign == Plus && self.denominator == *ONE.deref();
    positive_integer && self.numerator.bits() == 2 && self.numerator == *TWO.deref()
}
/// Returns whether the value is exactly `+1/2`. The bit-length tests are cheap early rejects.
#[inline]
pub(crate) fn is_one_half(&self) -> bool {
    let shapes_match = self.numerator.bits() == 1 && self.denominator.bits() == 2;
    self.sign == Plus
        && shapes_match
        && self.numerator == *ONE.deref()
        && self.denominator == *TWO.deref()
}
#[inline]
pub(crate) fn is_minus_one(&self) -> bool {
self.sign == Minus && self.numerator == *ONE.deref() && self.denominator == *ONE.deref()
}
#[inline]
pub(crate) fn cmp_one_structural(&self) -> Ordering {
match self.sign {
Minus | NoSign => Ordering::Less,
Plus => self.numerator.cmp(&self.denominator),
}
}
/// Compares `|self|` with `1` by comparing the numerator and denominator magnitudes.
#[inline]
pub(crate) fn abs_cmp_one_structural(&self) -> Ordering {
    Ord::cmp(&self.numerator, &self.denominator)
}
pub fn trunc(&self) -> Self {
if self.is_integer() {
return self.clone();
}
let n = &self.numerator / &self.denominator;
Self {
sign: self.sign,
numerator: n,
denominator: ONE.deref().clone(),
}
}
/// Returns `self - self.trunc()`: the fractional part, carrying the sign of `self`
/// (e.g. `-7/2 -> -1/2`).
///
/// The result goes through `from_fraction_parts`, so a zero remainder on an unreduced
/// input still yields the canonical zero rather than a signed `0/d`.
pub fn fract(&self) -> Self {
    if self.is_integer() {
        return Self::zero();
    }
    let remainder = &self.numerator % &self.denominator;
    Self::from_fraction_parts(self.sign, remainder, self.denominator.clone())
}
pub(crate) fn denominator(&self) -> &BigUint {
&self.denominator
}
pub(crate) fn numerator(&self) -> &BigUint {
&self.numerator
}
/// Splits `self` into `(k, rest)` with `self == rest * 2^k`. Every factor of two is removed
/// from the denominator, and from the numerator when it is nonzero.
pub(crate) fn factor_two_powers(&self) -> (i32, Self) {
    let as_i32 = |bits: u64| i32::try_from(bits).expect("shift should fit in i32");
    let as_usize = |bits: u64| usize::try_from(bits).expect("shift should fit in usize");
    // Zero has no trailing-zero count; treat it as zero twos.
    let numerator_twos = self.numerator.trailing_zeros().unwrap_or(0);
    let denominator_twos = self
        .denominator
        .trailing_zeros()
        .expect("Rational denominators are never zero");
    let exponent = as_i32(numerator_twos) - as_i32(denominator_twos);
    let rest = Self {
        sign: self.sign,
        numerator: &self.numerator >> as_usize(numerator_twos),
        denominator: &self.denominator >> as_usize(denominator_twos),
    };
    (exponent, rest)
}
/// Returns `self / 2^shift`, or `None` when `shift` is negative.
///
/// Factors of two are first cancelled from the numerator, and only the remainder of the
/// shift grows the denominator, so reduced inputs stay reduced.
pub(crate) fn divide_by_power_of_two(&self, shift: i32) -> Option<Self> {
    // Negative shifts are rejected by the conversion.
    let shift = u64::try_from(shift).ok()?;
    if self.is_zero() || self.numerator.is_zero() {
        return Some(Self::zero());
    }
    let to_usize = |bits: u64| usize::try_from(bits).expect("shift should fit");
    let cancelled = self
        .numerator
        .trailing_zeros()
        .expect("non-zero numerator has trailing zeros")
        .min(shift);
    let remaining = shift - cancelled;
    let numerator = &self.numerator >> to_usize(cancelled);
    let denominator = &self.denominator << to_usize(remaining);
    Some(Self::from_fraction_parts(self.sign, numerator, denominator))
}
#[inline]
pub(crate) fn power_of_two_shift(&self) -> Option<(i32, Sign)> {
if self.sign == NoSign {
return None;
}
let numerator_shift = self
.numerator
.trailing_zeros()
.expect("Rational numerators are never zero for non-zero signs");
if numerator_shift != self.numerator.bits() - 1 {
return None;
}
let denominator_shift = self
.denominator
.trailing_zeros()
.expect("Rational denominators are never zero");
if denominator_shift != self.denominator.bits() - 1 {
return None;
}
let numerator_shift = i32::try_from(numerator_shift).ok()?;
let denominator_shift = i32::try_from(denominator_shift).ok()?;
Some((numerator_shift - denominator_shift, self.sign))
}
pub(crate) fn magnitude_at_least_power_of_two(&self, exponent: u32) -> bool {
if self.sign == NoSign {
return false;
}
let numerator_bits = self.numerator.bits();
let target_bits = self.denominator.bits() + u64::from(exponent);
if numerator_bits > target_bits {
return true;
}
if numerator_bits < target_bits {
return false;
}
self.numerator >= (&self.denominator << exponent as usize)
}
/// Returns `floor(log2(|self|))` exactly, or `None` for zero.
///
/// The bit-length difference `c` is either the answer or one too large. It is one too
/// large exactly when `numerator < denominator * 2^c`.
pub(crate) fn msd_exact(&self) -> Option<i32> {
    if self.sign == NoSign {
        return None;
    }
    let guess = self.numerator.bits() as i32 - self.denominator.bits() as i32;
    let overshoots = if guess < 0 {
        (&self.numerator << (-guess) as usize) < self.denominator
    } else {
        self.numerator < (&self.denominator << guess as usize)
    };
    Some(if overshoots { guess - 1 } else { guess })
}
pub(crate) fn to_f64_lossy(&self) -> Option<f64> {
if self.sign == NoSign {
return Some(0.0);
}
let msd = self.msd_exact()?;
if msd > 1023 {
return None;
}
if msd < -1075 {
return Some(0.0);
}
let numerator = self.numerator.to_f64()?;
let denominator = self.denominator.to_f64()?;
let value = numerator / denominator;
if !value.is_finite() {
return None;
}
Some(match self.sign {
Minus => -value,
NoSign => 0.0,
Plus => value,
})
}
pub fn prefer_fraction(&self) -> bool {
let mut rem = self.denominator.clone();
while (&rem % &*TEN).is_zero() {
rem /= &*TEN;
}
while (&rem % &*FIVE).is_zero() {
rem /= &*FIVE;
}
while (&rem % &*TWO).is_zero() {
rem /= &*TWO;
}
rem != BigUint::one()
}
/// Returns `self * 2^shift` truncated toward zero, as a `BigInt`.
///
/// The magnitude is floored and the sign of `self` reapplied. A positive `shift` scales
/// the numerator; a negative `shift` scales the denominator instead. `BigUint` panics on a
/// negative left-shift amount, so a negative shift previously panicked.
pub fn shifted_big_integer(&self, shift: i32) -> BigInt {
    let whole = if shift >= 0 {
        (&self.numerator << shift) / &self.denominator
    } else {
        &self.numerator / (&self.denominator << shift.unsigned_abs())
    };
    // `from_biguint` turns a zero magnitude into `NoSign`.
    BigInt::from_biguint(self.sign, whole)
}
/// Compares `|self|` with `|other|`.
#[inline]
pub(crate) fn compare_magnitude(&self, other: &Self) -> Ordering {
    if self.same_denominator(other) {
        return self.numerator.cmp(&other.numerator);
    }
    // With a shared large numerator, skip the multiplication: the smaller denominator wins.
    if self.numerator.bits() > 64 && self.numerator == other.numerator {
        return other.denominator.cmp(&self.denominator);
    }
    let lhs = &self.numerator * &other.denominator;
    let rhs = &other.numerator * &self.denominator;
    lhs.cmp(&rhs)
}
/// Compares `|self|^2 * |factor|` with `|other|`, using cross-multiplication
/// (all denominators are positive).
#[inline]
pub(crate) fn compare_magnitude_squared_times(&self, factor: &Self, other: &Self) -> Ordering {
    let numerator_squared = &self.numerator * &self.numerator;
    let denominator_squared = &self.denominator * &self.denominator;
    let lhs = numerator_squared * &factor.numerator * &other.denominator;
    let rhs = denominator_squared * &factor.denominator * &other.numerator;
    lhs.cmp(&rhs)
}
/// Summarizes structural properties of the value: integrality, dyadic-ness,
/// power-of-two-ness and a size class based on bit lengths.
#[inline]
pub(crate) fn detailed_rational_facts(&self) -> RationalFacts {
    let is_zero = self.sign == NoSign;
    let numerator_bits = self.numerator.bits();
    let denominator_bits = self.denominator.bits();
    let exact_integer = self.denominator == *ONE.deref();
    let exact_dyadic = Self::is_power_of_two(&self.denominator);
    let storage = if is_zero {
        RationalStorageClass::Zero
    } else if numerator_bits.max(denominator_bits) <= 64 {
        RationalStorageClass::WordSized
    } else if numerator_bits.saturating_add(denominator_bits) <= 4096 {
        RationalStorageClass::MultiLimb
    } else {
        RationalStorageClass::VeryLarge
    };
    RationalFacts {
        exact_integer,
        // Fits in i64 as a non-negative magnitude of at most 63 bits (or is zero).
        exact_small_integer_i64: exact_integer && (is_zero || numerator_bits <= 63),
        exact_dyadic,
        power_of_two: !is_zero && Self::is_power_of_two(&self.numerator) && exact_dyadic,
        storage,
    }
}
/// Returns the value as a `BigInt` when it is an integer, otherwise `None`.
pub fn to_big_integer(&self) -> Option<BigInt> {
    let magnitude = self.integer_magnitude()?;
    Some(BigInt::from_biguint(self.sign, magnitude.clone()))
}
pub(crate) fn integer_magnitude(&self) -> Option<&BigUint> {
(self.denominator == *ONE.deref()).then_some(&self.numerator)
}
/// Returns the value as an `i64` when it is an integer in `i64::MIN..=i64::MAX`.
///
/// The magnitude is read as `u64` so that `i64::MIN` (magnitude `2^63`, which exceeds
/// `i64::MAX`) converts correctly; it was previously rejected.
pub(crate) fn to_integer_i64(&self) -> Option<i64> {
    let magnitude = self.integer_magnitude()?.to_u64()?;
    match self.sign {
        Plus | NoSign => i64::try_from(magnitude).ok(),
        // 0 - magnitude is exact for magnitudes up to 2^63 and `None` beyond.
        Minus => 0_i64.checked_sub_unsigned(magnitude),
    }
}
#[inline]
pub fn sign(&self) -> Sign {
self.sign
}
// Bit-length limit above which `extract_square` returns its input untouched.
const EXTRACT_SQUARE_MAX_LEN: u64 = 5_000;
/// Pairs each prime from 2 to 17 with its square, for stripping square factors in
/// `extract_square`.
fn make_squares() -> Vec<(BigUint, BigUint)> {
    [2_u16, 3, 5, 7, 11, 13, 17]
        .into_iter()
        .map(|prime| (BigUint::from(prime), BigUint::from(prime * prime)))
        .collect()
}
/// Returns `Some(sqrt(n))` when `n` is a perfect square, otherwise `None`.
fn try_perfect(n: BigUint) -> Option<BigUint> {
    let candidate = n.sqrt();
    let is_exact = &candidate * &candidate == n;
    is_exact.then_some(candidate)
}
/// Best-effort split of `n` into `(root, rest)` with `n == root^2 * rest`.
///
/// It divides out squares of the primes from `make_squares`, then checks whether the
/// remainder is a small divisor times a perfect square. `rest` is not guaranteed to be
/// square-free. Inputs longer than `EXTRACT_SQUARE_MAX_LEN` bits come back as `(1, n)`.
///
/// NOTE: `n` must be nonzero. For zero the `while` loop below never terminates, because
/// `0 % s == 0` and `0 / s == 0`.
fn extract_square(n: BigUint) -> (BigUint, BigUint) {
static SQUARES: LazyLock<Vec<(BigUint, BigUint)>> = LazyLock::new(Rational::make_squares);
let one: BigUint = One::one();
let mut root = one.clone();
let mut rest = n;
if rest.bits() > Self::EXTRACT_SQUARE_MAX_LEN {
return (root, rest);
}
// Each step keeps `root^2 * rest` equal to the original `n`.
for (p, s) in &*SQUARES {
if rest == one {
break;
}
while (&rest % s).is_zero() {
rest /= s;
root *= p;
}
}
// Candidate square-free cofactors; the list depends on whether `rest` is odd.
let divisors = if rest.bit(0) {
[1_u8, 3, 5, 7, 11, 13, 15, 17, 19]
} else {
[1_u8, 2, 3, 5, 6, 7, 8, 10, 11]
};
for n in divisors {
let divisor = BigUint::from(n);
if rest == divisor {
return (root, rest);
}
// If rest == divisor * k^2, move k into the root.
if (&rest % &divisor).is_zero() {
let square = &rest / &divisor;
if let Some(factor) = Self::try_perfect(square) {
return (root * factor, divisor);
}
}
}
(root, rest)
}
/// Splits `self` into `(root, rest)` with `self == root^2 * rest`, where `root` is positive.
///
/// The split is done separately on the numerator and on the denominator. Zero yields
/// `(0, 0)`.
pub fn extract_square_reduced(self) -> (Self, Self) {
    let Self {
        sign,
        numerator,
        denominator,
    } = self;
    if sign == NoSign {
        return (Self::zero(), Self::zero());
    }
    let (numerator_root, numerator_rest) = Self::extract_square(numerator);
    let (denominator_root, denominator_rest) = Self::extract_square(denominator);
    let root = Self {
        sign: Plus,
        numerator: numerator_root,
        denominator: denominator_root,
    };
    let rest = Self {
        sign,
        numerator: numerator_rest,
        denominator: denominator_rest,
    };
    (root, rest)
}
/// Returns whether `extract_square` would actually attempt factorization of both parts.
///
/// `extract_square` only gives up when a part has more than `EXTRACT_SQUARE_MAX_LEN`
/// bits, so the bound here is inclusive. It previously used `<`, which reported `false`
/// for parts of exactly that length.
pub fn extract_square_will_succeed(&self) -> bool {
    self.numerator.bits() <= Self::EXTRACT_SQUARE_MAX_LEN
        && self.denominator.bits() <= Self::EXTRACT_SQUARE_MAX_LEN
}
fn pow_up(&self, exp: &BigUint) -> Self {
if exp == &BigUint::ZERO {
return Self::one();
}
if let Some(exp) = exp.to_u32().filter(|exp| *exp <= 64) {
return Self {
numerator: self.numerator.pow(exp),
denominator: self.denominator.pow(exp),
sign: if self.sign == Minus && exp % 2 == 1 {
Minus
} else {
Plus
},
};
}
let mut result = Self::one();
let mut factor = self.clone();
let bits = exp.bits();
for b in 0..bits {
if exp.bit(b) {
result *= &factor;
}
if b + 1 < bits {
factor = &factor * &factor;
}
}
result
}
/// Raises `self` to the integer power `exp`; a negative exponent inverts the base first.
///
/// # Errors
/// * `Problem::DivideByZero` for zero raised to a negative power. This case previously
///   returned zero silently.
/// * `Problem::Exhausted` when `exp` has 1000 or more bits and the base is not 0 or ±1.
pub fn powi(self, exp: BigInt) -> Result<Self, Problem> {
    const TOO_MANY_BITS: u64 = 1000;
    if exp == BigInt::ZERO {
        return Ok(Self::one());
    }
    if self.sign == NoSign {
        // 0^n is 0 for n > 0, but 0^-n = 1 / 0^n divides by zero.
        return if exp.sign() == Minus {
            Err(Problem::DivideByZero)
        } else {
            Ok(Self::zero())
        };
    }
    // Powers of +1 and -1 are cheap regardless of exponent size.
    if self.is_integer() && self.numerator == *ONE.deref() {
        if self.sign == Minus && exp.bit(0) {
            return Ok(Self::new(-1));
        } else {
            return Ok(Self::one());
        }
    }
    if exp.bits() >= TOO_MANY_BITS {
        return Err(Problem::Exhausted);
    }
    match exp.sign() {
        Minus => Ok(self.inverse()?.pow_up(exp.magnitude())),
        Plus => Ok(self.pow_up(exp.magnitude())),
        NoSign => unreachable!(),
    }
}
}
// Lets the arithmetic operator impls accept either `Rational` or `&Rational` operands.
impl AsRef<Rational> for Rational {
    fn as_ref(&self) -> &Self {
        self
    }
}
use core::fmt;
impl fmt::Display for Rational {
    /// Formats the value.
    ///
    /// * Integers go through `Formatter::pad_integral`, so width, fill and `+` apply.
    /// * `{:#}` prints a decimal expansion truncated (not rounded) after `precision` digits.
    ///   The default is 1000 digits, and a precision of 0 is treated as 1. Output stops
    ///   early once the expansion terminates.
    /// * Otherwise a mixed fraction is printed, such as `1 1/2` or `-3/4`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.denominator == *ONE.deref() {
            let int = self.numerator.to_string();
            return f.pad_integral(self.sign != Minus, "", &int);
        }
        if self.sign == Minus {
            f.write_str("-")?;
        } else if f.sign_plus() {
            f.write_str("+")?;
        }
        let whole = &self.numerator / &self.denominator;
        let mut left = &self.numerator - &whole * &self.denominator;
        if f.alternate() {
            write!(f, "{whole}.")?;
            // At least one fractional digit is always printed. The `max(1)` also keeps
            // `{:#.0}` from underflowing the old digit countdown.
            let max_digits = f.precision().unwrap_or(1000).max(1);
            // Long division: each step produces one decimal digit of the remainder.
            for _ in 0..max_digits {
                left *= &*TEN;
                let digit = &left / &self.denominator;
                left -= &digit * &self.denominator;
                write!(f, "{digit}")?;
                if left.is_zero() {
                    break;
                }
            }
            Ok(())
        } else if whole.is_zero() {
            write!(f, "{left}/{}", self.denominator)
        } else {
            write!(f, "{whole} {left}/{}", self.denominator)
        }
    }
}
impl std::str::FromStr for Rational {
    type Err = Problem;
    /// Parses one of three forms, each with an optional leading `-`:
    ///
    /// * `num/den` (e.g. `-3/4`), reduced to lowest terms;
    /// * `int.frac` (e.g. `2.50`), where `frac` must be ASCII digits only;
    /// * `int` (e.g. `42`).
    ///
    /// # Errors
    /// Returns `Problem::BadFraction`, `Problem::BadDecimal` or `Problem::BadInteger` when
    /// the matching form fails to parse, and `Problem::DivideByZero` for a zero denominator.
    fn from_str(s: &str) -> Result<Self, Problem> {
        let (sign, s) = match s.strip_prefix('-') {
            Some(rest) => (Minus, rest),
            None => (Plus, s),
        };
        if let Some((n, d)) = s.split_once('/') {
            let numerator = BigUint::parse_bytes(n.as_bytes(), 10).ok_or(Problem::BadFraction)?;
            let denominator = BigUint::parse_bytes(d.as_bytes(), 10).ok_or(Problem::BadFraction)?;
            if denominator.is_zero() {
                return Err(Problem::DivideByZero);
            }
            // `from_fraction_parts` turns a zero numerator into the canonical zero.
            Ok(Self::from_fraction_parts(sign, numerator, denominator).reduce())
        } else if let Some((i, d)) = s.split_once('.') {
            let whole = BigUint::parse_bytes(i.as_bytes(), 10).ok_or(Problem::BadDecimal)?;
            let whole = Self::from_integer_magnitude(sign, whole);
            // The scale below is 10^len(d), so every byte of `d` must be a digit.
            // `parse_bytes` would otherwise skip `_` separators or accept a leading `+`,
            // and the value would be mis-scaled.
            if !d.bytes().all(|b| b.is_ascii_digit()) {
                return Err(Problem::BadDecimal);
            }
            let fraction = BigUint::parse_bytes(d.as_bytes(), 10).ok_or(Problem::BadDecimal)?;
            if fraction.is_zero() {
                return Ok(whole);
            }
            let scale = u32::try_from(d.len()).map_err(|_| Problem::BadDecimal)?;
            // Reduce before adding. `Add` only cancels factors shared by the two
            // denominators, and returns the fraction unchanged when `whole` is zero.
            // An unreduced `5/10` would otherwise survive into the result.
            let fraction = Self::from_fraction_parts(sign, fraction, TEN.pow(scale)).reduce();
            Ok(whole + fraction)
        } else {
            let numerator = BigUint::parse_bytes(s.as_bytes(), 10).ok_or(Problem::BadInteger)?;
            Ok(Self::from_integer_magnitude(sign, numerator))
        }
    }
}
use core::ops::*;
impl<T: AsRef<Rational>> Add<T> for &Rational {
type Output = Rational;
fn add(self, other: T) -> Self::Output {
use std::cmp::Ordering::*;
let other = other.as_ref();
if self.sign == NoSign {
return other.clone();
}
if other.sign == NoSign {
return self.clone();
}
if self.is_one() {
return other.add_one();
}
if other.is_one() {
return self.add_one();
}
let common_denominator = num::Integer::gcd(&self.denominator, &other.denominator);
trace_rational_gcd!(&self.denominator, &other.denominator, &common_denominator);
let left_scale = &other.denominator / &common_denominator;
let right_scale = &self.denominator / &common_denominator;
let denominator = &self.denominator * &left_scale;
let a = &self.numerator * &left_scale;
let b = &other.numerator * &right_scale;
let (sign, numerator) = match (self.sign, other.sign) {
(Plus, Plus) => (Plus, a + b),
(Minus, Minus) => (Minus, a + b),
(x, y) => match a.cmp(&b) {
Greater => (x, a - b),
Equal => {
return Self::Output::zero();
}
Less => (y, b - a),
},
};
trace_rational_temporary!();
Self::Output {
sign,
numerator,
denominator,
}
.reduce_with_possible_divisor(&common_denominator)
}
}
impl<T: AsRef<Rational>> Add<T> for Rational {
    type Output = Self;

    /// Owned addition; delegates to the by-reference implementation.
    fn add(self, other: T) -> Self {
        let rhs: &Rational = other.as_ref();
        Add::add(&self, rhs)
    }
}
impl Neg for &Rational {
type Output = Rational;
fn neg(self) -> Self::Output {
trace_rational_temporary!();
let mut ret = self.clone();
ret.sign = -ret.sign;
ret
}
}
impl Neg for Rational {
type Output = Self;
fn neg(mut self) -> Self {
self.sign = -self.sign;
self
}
}
impl<T: AsRef<Rational>> Sub<T> for &Rational {
type Output = Rational;
fn sub(self, other: T) -> Self::Output {
use std::cmp::Ordering::*;
let other = other.as_ref();
if other.sign == NoSign {
return self.clone();
}
if self.sign == NoSign {
return -other;
}
if other.is_one() {
return self.subtract_one();
}
if self.is_one() {
return -other.subtract_one();
}
let common_denominator = num::Integer::gcd(&self.denominator, &other.denominator);
trace_rational_gcd!(&self.denominator, &other.denominator, &common_denominator);
let left_scale = &other.denominator / &common_denominator;
let right_scale = &self.denominator / &common_denominator;
let denominator = &self.denominator * &left_scale;
let a = &self.numerator * &left_scale;
let b = &other.numerator * &right_scale;
let (sign, numerator) = match (self.sign, other.sign) {
(Plus, Minus) => (Plus, a + b),
(Minus, Plus) => (Minus, a + b),
(x, y) => match a.cmp(&b) {
Greater => (x, a - b),
Equal => {
return Self::Output::zero();
}
Less => (-y, b - a),
},
};
trace_rational_temporary!();
Self::Output {
sign,
numerator,
denominator,
}
.reduce_with_possible_divisor(&common_denominator)
}
}
impl<T: AsRef<Rational>> Sub<T> for Rational {
    type Output = Self;

    /// Owned subtraction; delegates to the by-reference implementation.
    fn sub(self, other: T) -> Self {
        let rhs: &Rational = other.as_ref();
        Sub::sub(&self, rhs)
    }
}
impl<T: AsRef<Rational>> Mul<T> for &Rational {
type Output = Rational;
fn mul(self, other: T) -> Self::Output {
let other = other.as_ref();
let sign = self.sign * other.sign;
if sign == NoSign {
return Self::Output::zero();
}
let numerator = &self.numerator * &other.numerator;
let denominator = &self.denominator * &other.denominator;
trace_rational_temporary!();
Self::Output::maybe_reduce(Self::Output {
sign,
numerator,
denominator,
})
}
}
impl<T: AsRef<Rational>> Mul<T> for Rational {
    type Output = Self;

    /// Owned multiplication; delegates to the by-reference implementation.
    fn mul(self, other: T) -> Self {
        let rhs: &Rational = other.as_ref();
        Mul::mul(&self, rhs)
    }
}
impl<T: AsRef<Rational>> MulAssign<T> for Rational {
fn mul_assign(&mut self, other: T) {
*self = &*self * other.as_ref();
}
}
impl<T: AsRef<Rational>> Div<T> for &Rational {
type Output = Rational;
fn div(self, other: T) -> Self::Output {
let other = other.as_ref();
assert_ne!(other.numerator, BigUint::ZERO);
let sign = self.sign * other.sign;
if sign == NoSign {
return Self::Output::zero();
}
if self.numerator == other.denominator && self.denominator == other.numerator {
trace_rational_temporary!();
return Self::Output {
sign,
numerator: &self.numerator * &self.numerator,
denominator: &self.denominator * &self.denominator,
};
}
let numerator = &self.numerator * &other.denominator;
let denominator = &self.denominator * &other.numerator;
trace_rational_temporary!();
Self::Output::maybe_reduce(Self::Output {
sign,
numerator,
denominator,
})
}
}
impl<T: AsRef<Rational>> Div<T> for Rational {
    type Output = Self;

    /// Owned division; delegates to the by-reference implementation.
    fn div(self, other: T) -> Self {
        let divisor: &Rational = other.as_ref();
        Div::div(&self, divisor)
    }
}
impl Rational {
    /// Field-by-field equality: sign, denominator and numerator must all match.
    ///
    /// Only meaningful as value equality when both operands are in the same
    /// (e.g. reduced) representation.
    fn definitely_equal(&self, other: &Self) -> bool {
        self.sign == other.sign
            && self.denominator == other.denominator
            && self.numerator == other.numerator
    }
}
impl PartialEq for Rational {
    /// Value equality.
    ///
    /// Differing signs are unequal. Equal denominators compare numerators
    /// directly. Otherwise both sides are cloned, reduced, and compared field by
    /// field.
    fn eq(&self, other: &Self) -> bool {
        if self.sign != other.sign {
            return false;
        }
        if self.denominator != other.denominator {
            let lhs = self.clone().reduce();
            let rhs = other.clone().reduce();
            return Self::definitely_equal(&lhs, &rhs);
        }
        self.numerator == other.numerator
    }
}
impl PartialOrd for Rational {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
use std::cmp::Ordering::*;
match self.sign.cmp(&other.sign) {
Less => return Some(Less),
Greater => return Some(Greater),
Equal => {
if self.sign == NoSign {
return Some(Equal);
}
}
}
if self.denominator == other.denominator {
match self.sign {
Plus => self.numerator.partial_cmp(&other.numerator),
Minus => other.numerator.partial_cmp(&self.numerator),
NoSign => unreachable!(),
}
} else {
let left = &self.numerator * &other.denominator;
let right = &other.numerator * &self.denominator;
match self.sign {
Plus => left.partial_cmp(&right),
Minus => right.partial_cmp(&left),
NoSign => unreachable!(),
}
}
}
}
#[cfg(test)]
// Unit tests for parsing, display, arithmetic operators, comparison and the
// specialised helpers (dot products, signed product sums, power-of-two checks).
mod tests {
use super::*;
// Integers print plainly; non-integers print as a mixed fraction "W N/D".
#[test]
fn display() {
let many: Rational = "12345".parse().unwrap();
let s = format!("{many}");
assert_eq!(s, "12345");
let five: Rational = "5".parse().unwrap();
let third: Rational = "1/3".parse().unwrap();
let s = format!("{}", five * third);
assert_eq!(s, "1 2/3");
}
// Decimal strings parse to exact values: 0.4 * 2.5 is exactly one.
#[test]
fn decimals() {
let first: Rational = "0.0".parse().unwrap();
assert_eq!(first, Rational::zero());
let a: Rational = "0.4".parse().unwrap();
let b: Rational = "2.5".parse().unwrap();
let answer = a * b;
assert_eq!(answer, Rational::one());
}
// Large integers keep full precision through multiplication.
#[test]
fn parse() {
let big: Rational = "288230376151711743".parse().unwrap();
let small: Rational = "45".parse().unwrap();
let expected: Rational = "12970366926827028435".parse().unwrap();
assert_eq!(big * small, expected);
}
// Mixed signed fraction arithmetic: 1/3 + (-4)(12/20) = -31/15.
#[test]
fn parse_fractions() {
let third: Rational = "1/3".parse().unwrap();
let minus_four: Rational = "-4".parse().unwrap();
let twelve: Rational = "12/20".parse().unwrap();
let answer = third + minus_four * twelve;
let expected: Rational = "-31/15".parse().unwrap();
assert_eq!(answer, expected);
}
// "n/0" is rejected, and parsed fractions are stored reduced (visible via Display).
#[test]
fn parse_fraction_rejects_zero_denominator_and_reduces() {
assert_eq!("1/0".parse::<Rational>(), Err(Problem::DivideByZero));
assert_eq!("0/0".parse::<Rational>(), Err(Problem::DivideByZero));
let reduced: Rational = "9/18".parse().unwrap();
assert_eq!(reduced, Rational::fraction(1, 2).unwrap());
assert_eq!(format!("{reduced}"), "1/2");
}
// Deserialization rejects a zero denominator and reduces unreduced input.
#[cfg(feature = "serde")]
#[test]
fn serde_rejects_invalid_or_uncanonical_rational_state() {
let bad = r#"{"sign":1,"numerator":[1],"denominator":[]}"#;
assert!(serde_json::from_str::<Rational>(bad).is_err());
let unreduced = r#"{"sign":1,"numerator":[9],"denominator":[18]}"#;
let decoded: Rational = serde_json::from_str(unreduced).unwrap();
assert_eq!(decoded, Rational::fraction(1, 2).unwrap());
assert_eq!(format!("{decoded}"), "1/2");
}
// 32 = 4^2 * 2; -1 has no square factor and keeps its sign in the remainder.
#[test]
fn square_reduced() {
let thirty_two = Rational::new(32);
let (square, rest) = thirty_two.extract_square_reduced();
let four = Rational::new(4);
assert_eq!(square, four);
let two = Rational::new(2);
assert_eq!(rest, two);
let minus_one = Rational::new(-1);
let (square, rest) = minus_one.clone().extract_square_reduced();
assert_eq!(square, Rational::one());
assert_eq!(rest, minus_one);
}
// Sign handling across subtraction and inversion: 1/(0 - 2) == 4/8 - 1.
#[test]
fn signs() {
let half: Rational = "4/8".parse().unwrap();
let one = Rational::one();
let minus_half = half - one;
let two = Rational::new(2);
let zero = Rational::zero();
let minus_two = zero - two;
let i2 = minus_two.inverse().unwrap();
assert_eq!(i2, minus_half);
}
#[test]
fn half_plus_one_times_two() {
let two = Rational::new(2);
let half = two.inverse().unwrap();
let one = Rational::one();
let two = Rational::new(2);
let three = Rational::new(3);
let sum = half + one;
assert_eq!(sum * two, three);
}
#[test]
fn three_divided_by_six() {
let three = Rational::new(3);
let six = Rational::new(6);
let half: Rational = "1/2".parse().unwrap();
assert_eq!(three / six, half);
}
#[test]
fn one_plus_two() {
let one = Rational::one();
let two = Rational::new(2);
let three = Rational::new(3);
assert_eq!(one + two, three);
}
#[test]
fn two_minus_one() {
let two = Rational::new(2);
let one = Rational::one();
assert_eq!(two - one, Rational::one());
}
#[test]
fn two_times_three() {
let two = Rational::new(2);
let three = Rational::new(3);
assert_eq!(two * three, Rational::new(6));
}
// fract() keeps the sign of the input: fract(-70/9) == -7/9.
#[test]
fn fract() {
let seventy_ninths = Rational::fraction(70, 9).unwrap();
assert_eq!(seventy_ninths.fract(), Rational::fraction(7, 9).unwrap());
assert_eq!(
seventy_ninths.neg().fract(),
Rational::fraction(-7, 9).unwrap()
);
let six = Rational::new(6);
assert_eq!(six.fract(), Rational::zero());
}
// trunc() + fract() reconstructs the original value, including for negatives and zero.
#[test]
fn trunc() {
let seventy_ninths = Rational::fraction(70, 9).unwrap();
let whole = seventy_ninths.trunc();
let frac = seventy_ninths.fract();
assert_eq!(whole + frac, seventy_ninths);
let shrink = Rational::fraction(-405, 11).unwrap();
let whole = shrink.trunc();
let frac = shrink.fract();
assert_eq!(whole + frac, shrink);
let zero = Rational::zero();
let whole = zero.trunc();
let frac = zero.fract();
assert_eq!(whole, frac);
assert_eq!(whole + frac, zero);
}
// Negative and positive integer powers.
#[test]
fn power() {
let one_two_five = Rational::new(5).powi(BigInt::from(-3));
assert_eq!(one_two_five, Rational::fraction(1, 125));
let more = Rational::new(7).powi(11i32.into()).unwrap();
assert_eq!(more, Rational::new(1_977_326_743));
}
// Each row is (n, root, rest) with n == root^2 * rest.
#[test]
fn sqrt_trouble() {
for (n, root, rest) in [
(1, 1, 1),
(2, 1, 2),
(3, 1, 3),
(4, 2, 1),
(16, 4, 1),
(400, 20, 1),
(1323, 21, 3),
(4761, 69, 1),
(123456, 8, 1929),
(715716, 846, 1),
] {
let n = Rational::new(n);
let reduced = n.extract_square_reduced();
assert_eq!(reduced, (Rational::new(root), Rational::new(rest)));
}
}
// Values with terminating decimal expansions (7.125, 1/2) do not prefer
// fraction form; 1/3 does.
#[test]
fn decimal() {
let decimal: Rational = "7.125".parse().unwrap();
assert!(!decimal.prefer_fraction());
let half: Rational = "4/8".parse().unwrap();
assert!(!half.prefer_fraction());
let third: Rational = "2/6".parse().unwrap();
assert!(third.prefer_fraction());
}
#[test]
fn power_of_two_shift_detects_only_power_of_two_ratios() {
assert_eq!(
Rational::fraction(8, 1).unwrap().power_of_two_shift(),
Some((3, Plus))
);
assert_eq!(
Rational::fraction(1, 8).unwrap().power_of_two_shift(),
Some((-3, Plus))
);
assert_eq!(
Rational::fraction(-4, 32).unwrap().power_of_two_shift(),
Some((-3, Minus))
);
assert_eq!(Rational::fraction(7, 8).unwrap().power_of_two_shift(), None);
assert_eq!(Rational::fraction(5, 6).unwrap().power_of_two_shift(), None);
assert_eq!(Rational::zero().power_of_two_shift(), None);
}
// add_one/subtract_one must agree with the generic + and - operators.
#[test]
fn add_and_subtract_one_helpers_match_generic_arithmetic() {
for value in [
Rational::zero(),
Rational::one(),
Rational::new(-1),
Rational::fraction(7, 4).unwrap(),
Rational::fraction(3, 5).unwrap(),
Rational::fraction(-7, 4).unwrap(),
Rational::fraction(-3, 5).unwrap(),
] {
assert_eq!(value.add_one(), value.clone() + Rational::one());
assert_eq!(value.subtract_one(), value.clone() - Rational::one());
}
}
// |x| >= 2^3 checks around the boundary, for integers and fractions.
#[test]
fn magnitude_at_least_power_of_two_handles_threshold_boundaries() {
assert!(
!Rational::fraction(7, 1)
.unwrap()
.magnitude_at_least_power_of_two(3)
);
assert!(
Rational::fraction(8, 1)
.unwrap()
.magnitude_at_least_power_of_two(3)
);
assert!(
Rational::fraction(-9, 1)
.unwrap()
.magnitude_at_least_power_of_two(3)
);
assert!(
!Rational::fraction(15, 2)
.unwrap()
.magnitude_at_least_power_of_two(3)
);
assert!(
Rational::fraction(16, 2)
.unwrap()
.magnitude_at_least_power_of_two(3)
);
assert!(!Rational::zero().magnitude_at_least_power_of_two(3));
}
// Adding/subtracting power-of-two denominators yields the reduced result.
#[test]
fn dyadic_add_sub_stay_reduced() {
let three_eighths = Rational::fraction(3, 8).unwrap();
let five_sixteenths = Rational::fraction(5, 16).unwrap();
assert_eq!(
&three_eighths + &five_sixteenths,
Rational::fraction(11, 16).unwrap()
);
assert_eq!(
&three_eighths - &five_sixteenths,
Rational::fraction(1, 16).unwrap()
);
assert_eq!(
&five_sixteenths - &three_eighths,
Rational::fraction(-1, 16).unwrap()
);
assert_eq!(&three_eighths - &three_eighths, Rational::zero());
}
#[test]
fn same_denominator_reports_reduced_common_scale() {
let a = Rational::fraction(3, 10).unwrap();
let b = Rational::fraction(-7, 10).unwrap();
let reduced = Rational::fraction(6, 20).unwrap();
let c = Rational::fraction(1, 3).unwrap();
assert!(a.same_denominator(&b));
assert!(a.same_denominator(&reduced));
assert!(!a.same_denominator(&c));
}
// dot_products must equal the sum of pairwise products computed with the operators.
#[test]
fn dot_products_match_pairwise_arithmetic() {
let left = [
Rational::fraction(3, 8).unwrap(),
Rational::fraction(-5, 16).unwrap(),
Rational::zero(),
Rational::fraction(7, 10).unwrap(),
];
let right = [
Rational::fraction(11, 32).unwrap(),
Rational::fraction(13, 64).unwrap(),
Rational::fraction(17, 19).unwrap(),
Rational::fraction(-23, 25).unwrap(),
];
let expected = &(&left[0] * &right[0])
+ &(&left[1] * &right[1])
+ &(&left[2] * &right[2])
+ &(&left[3] * &right[3]);
assert_eq!(
Rational::dot_products(
[&left[0], &left[1], &left[2], &left[3]],
[&right[0], &right[1], &right[2], &right[3]],
),
expected
);
}
#[test]
fn dot_products_preserve_dyadic_exactness() {
let left = [
Rational::fraction(1, 8).unwrap(),
Rational::fraction(3, 16).unwrap(),
Rational::fraction(-5, 32).unwrap(),
];
let right = [
Rational::fraction(7, 4).unwrap(),
Rational::fraction(-11, 8).unwrap(),
Rational::fraction(13, 16).unwrap(),
];
let dot = Rational::dot_products(
[&left[0], &left[1], &left[2]],
[&right[0], &right[1], &right[2]],
);
assert!(dot.is_dyadic());
assert_eq!(
dot,
&(&left[0] * &right[0]) + &(&left[1] * &right[1]) + &(&left[2] * &right[2])
);
}
#[test]
fn dot_products_handle_equal_non_dyadic_denominators() {
let left = [
Rational::fraction(7, 10).unwrap(),
Rational::fraction(-9, 10).unwrap(),
Rational::fraction(11, 10).unwrap(),
];
let right = [
Rational::fraction(13, 7).unwrap(),
Rational::fraction(5, 7).unwrap(),
Rational::fraction(-3, 7).unwrap(),
];
assert_eq!(
Rational::dot_products(
[&left[0], &left[1], &left[2]],
[&right[0], &right[1], &right[2]],
),
&(&left[0] * &right[0]) + &(&left[1] * &right[1]) + &(&left[2] * &right[2])
);
}
// signed_product_sum: `true` adds a term's product, `false` subtracts it.
#[test]
fn signed_product_sum_matches_pairwise_arithmetic() {
let terms = [
[
Rational::fraction(3, 8).unwrap(),
Rational::fraction(-5, 12).unwrap(),
Rational::fraction(7, 11).unwrap(),
],
[
Rational::fraction(13, 9).unwrap(),
Rational::fraction(17, 25).unwrap(),
Rational::fraction(-19, 6).unwrap(),
],
[
Rational::fraction(-23, 10).unwrap(),
Rational::fraction(29, 14).unwrap(),
Rational::fraction(31, 15).unwrap(),
],
];
let expected = &(&terms[0][0] * &terms[0][1] * &terms[0][2])
- &(&terms[1][0] * &terms[1][1] * &terms[1][2])
+ &(&terms[2][0] * &terms[2][1] * &terms[2][2]);
assert_eq!(
Rational::signed_product_sum(
[true, false, true],
[
[&terms[0][0], &terms[0][1], &terms[0][2]],
[&terms[1][0], &terms[1][1], &terms[1][2]],
[&terms[2][0], &terms[2][1], &terms[2][2]],
],
),
expected
);
}
#[test]
fn signed_product_sum_preserves_dyadic_exactness() {
let terms = [
[
Rational::fraction(1, 8).unwrap(),
Rational::fraction(3, 16).unwrap(),
],
[
Rational::fraction(5, 32).unwrap(),
Rational::fraction(7, 64).unwrap(),
],
[
Rational::fraction(-9, 4).unwrap(),
Rational::fraction(11, 8).unwrap(),
],
];
let sum = Rational::signed_product_sum(
[true, false, true],
[
[&terms[0][0], &terms[0][1]],
[&terms[1][0], &terms[1][1]],
[&terms[2][0], &terms[2][1]],
],
);
assert!(sum.is_dyadic());
assert_eq!(
sum,
&(&terms[0][0] * &terms[0][1]) - &(&terms[1][0] * &terms[1][1])
+ &(&terms[2][0] * &terms[2][1])
);
}
#[test]
fn signed_product_sum_handles_equal_non_dyadic_denominators() {
let terms = [
[
Rational::fraction(7, 10).unwrap(),
Rational::fraction(13, 7).unwrap(),
],
[
Rational::fraction(9, 10).unwrap(),
Rational::fraction(5, 7).unwrap(),
],
[
Rational::fraction(11, 10).unwrap(),
Rational::fraction(3, 7).unwrap(),
],
];
assert_eq!(
Rational::signed_product_sum(
[true, false, true],
[
[&terms[0][0], &terms[0][1]],
[&terms[1][0], &terms[1][1]],
[&terms[2][0], &terms[2][1]],
],
),
&(&terms[0][0] * &terms[0][1]) - &(&terms[1][0] * &terms[1][1])
+ &(&terms[2][0] * &terms[2][1])
);
}
// The shared-denominator variant returns Some(sum) when every factor shares a
// denominator, and None once a factor with a different denominator is mixed in.
#[test]
fn signed_product_sum_shared_denominator_consumes_common_scale() {
let terms = [
[
Rational::fraction(7, 15).unwrap(),
Rational::fraction(13, 15).unwrap(),
],
[
Rational::fraction(8, 15).unwrap(),
Rational::fraction(-2, 15).unwrap(),
],
[
Rational::fraction(11, 15).unwrap(),
Rational::fraction(14, 15).unwrap(),
],
];
let expected = &(&terms[0][0] * &terms[0][1]) - &(&terms[1][0] * &terms[1][1])
+ &(&terms[2][0] * &terms[2][1]);
assert_eq!(
Rational::signed_product_sum_shared_denominator(
[true, false, true],
[
[&terms[0][0], &terms[0][1]],
[&terms[1][0], &terms[1][1]],
[&terms[2][0], &terms[2][1]],
],
),
Some(expected)
);
let mixed = Rational::fraction(1, 7).unwrap();
assert_eq!(
Rational::signed_product_sum_shared_denominator(
[true, false, true],
[
[&terms[0][0], &terms[0][1]],
[&terms[1][0], &mixed],
[&terms[2][0], &terms[2][1]],
],
),
None
);
}
#[test]
fn compare() {
assert!(Rational::one() > Rational::zero());
assert!(Rational::new(5) > Rational::new(4));
assert!(Rational::new(-10) < Rational::new(5));
assert!(Rational::fraction(1, 4).unwrap() < Rational::fraction(1, 3).unwrap());
}
// partial_cmp of a value with itself is Some(Equal) for zero, positive and negative values.
#[test]
fn same() {
use std::cmp::Ordering;
assert_eq!(
Rational::zero().partial_cmp(&Rational::zero()),
Some(Ordering::Equal)
);
assert_eq!(
Rational::one().partial_cmp(&Rational::one()),
Some(Ordering::Equal)
);
assert_eq!(
Rational::new(-10).partial_cmp(&Rational::new(-10)),
Some(Ordering::Equal)
);
}
#[test]
fn divide_by_zero() {
let err = Rational::fraction(1, 0).unwrap_err();
assert_eq!(err, Problem::DivideByZero);
let zero = Rational::zero();
let err = zero.inverse().unwrap_err();
assert_eq!(err, Problem::DivideByZero);
}
// Owned left-hand side with a borrowed right-hand side.
#[test]
fn operations_work_on_refs_on_rhs() {
let a = Rational::new(2);
let b = Rational::new(3);
let c = Rational::new(6);
assert_eq!(a.clone() * &b, c.clone());
assert_eq!(c.clone() / &b, a.clone());
assert_eq!(c.clone() - &a, Rational::new(4));
assert_eq!(-&c, Rational::new(-6));
assert_eq!(a.clone() + &b, Rational::new(5));
}
// Both operands borrowed.
#[test]
fn operations_work_on_refs() {
let a = Rational::new(2);
let b = Rational::new(3);
let c = Rational::new(6);
assert_eq!(&a * &b, c.clone());
assert_eq!(&c / &b, a.clone());
assert_eq!(&c - &a, Rational::new(4));
assert_eq!(-&c, Rational::new(-6));
assert_eq!(&a + &b, Rational::new(5));
}
}