use std::collections::BTreeMap;
use std::fmt;
use thiserror::Error;
/// An exact rational number `num / den`.
///
/// Values produced by this module's constructors and arithmetic are kept in
/// lowest terms with a positive denominator, so the derived `PartialEq` and
/// `Hash` agree with numeric equality for those values.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct Rational {
    num: i32,
    den: i32,
}

/// Debug output is identical to `Display` (e.g. `-3/2` or `4`).
impl fmt::Debug for Rational {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Display>::fmt(self, f)
    }
}

/// Prints just the numerator when the denominator is 1, otherwise `num/den`.
impl fmt::Display for Rational {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { num, den } = *self;
        match den {
            1 => write!(f, "{num}"),
            _ => write!(f, "{num}/{den}"),
        }
    }
}
impl Rational {
    /// Zero in canonical form (`0/1`).
    pub const ZERO: Self = Self { num: 0, den: 1 };
    /// One (`1/1`).
    pub const ONE: Self = Self { num: 1, den: 1 };
    /// One half (`1/2`).
    pub const HALF: Self = Self { num: 1, den: 2 };
    /// One third (`1/3`).
    pub const THIRD: Self = Self { num: 1, den: 3 };

    /// Builds `num / den` in lowest terms with a positive denominator.
    ///
    /// Zero is always returned as [`Rational::ZERO`].
    ///
    /// # Errors
    ///
    /// - [`RationalError::ZeroDenominator`] if `den == 0`.
    /// - [`RationalError::Overflow`] if the normalised numerator or denominator
    ///   would be `+2^31`, which `i32` cannot hold (e.g. `(i32::MIN, -1)` or
    ///   `(1, i32::MIN)`).
    pub fn try_new(num: i32, den: i32) -> Result<Self, RationalError> {
        if den == 0 {
            return Err(RationalError::ZeroDenominator);
        }
        if num == 0 {
            return Ok(Self::ZERO);
        }
        // Reduce the unsigned magnitudes so an `i32::MIN` operand never has to
        // be negated in `i32` (which would overflow).
        let (num_abs, den_abs) = (num.unsigned_abs(), den.unsigned_abs());
        let g = gcd(num_abs, den_abs);
        let magnitude = i64::from(num_abs / g);
        // The sign lives on the numerator: negative iff exactly one input is negative.
        let signed = if (num < 0) != (den < 0) {
            -magnitude
        } else {
            magnitude
        };
        let num = i32::try_from(signed).map_err(|_| RationalError::Overflow)?;
        let den = i32::try_from(den_abs / g).map_err(|_| RationalError::Overflow)?;
        Ok(Self { num, den })
    }

    /// Converts an integer `n` to the rational `n/1`.
    #[must_use]
    pub const fn from_int(n: i32) -> Self {
        // `0/1` is already the canonical zero, so no special case is needed.
        Self { num: n, den: 1 }
    }

    /// The numerator; for normalised values it carries the sign.
    #[must_use]
    pub const fn num(self) -> i32 {
        self.num
    }

    /// The denominator; positive for normalised values.
    #[must_use]
    pub const fn den(self) -> i32 {
        self.den
    }

    /// Returns `true` if the value is zero.
    #[must_use]
    pub const fn is_zero(self) -> bool {
        self.num == 0
    }

    /// Returns `true` if the denominator is 1, i.e. the value is a whole number.
    #[must_use]
    pub const fn is_integer(self) -> bool {
        self.den == 1
    }
}
/// Errors returned by checked [`Rational`] construction and arithmetic.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum RationalError {
    /// A zero denominator was supplied.
    #[error("denominator must not be zero")]
    ZeroDenominator,
    /// A reduced numerator or denominator did not fit in `i32`.
    #[error("dimension exponent overflowed i32")]
    Overflow,
}
/// Reduces the `i64` fraction `num / den` to lowest terms with a positive
/// denominator, then narrows both parts to `i32`.
///
/// Zero is normalised to `(0, 1)`.
///
/// # Errors
///
/// - [`RationalError::ZeroDenominator`] if `den == 0`.
/// - [`RationalError::Overflow`] if the normalised numerator or denominator
///   does not fit in `i32`. This includes inputs such as `(i64::MIN, -1)`,
///   which are reported rather than overflowing during sign normalisation.
fn reduce_i64(num: i64, den: i64) -> Result<(i32, i32), RationalError> {
    if den == 0 {
        return Err(RationalError::ZeroDenominator);
    }
    if num == 0 {
        return Ok((0, 1));
    }
    // Work on unsigned magnitudes so an `i64::MIN` operand is never negated.
    let (num_abs, den_abs) = (num.unsigned_abs(), den.unsigned_abs());
    let g = gcd64(num_abs, den_abs);
    // `num_abs / g` can be 2^63 (e.g. `num == i64::MIN` with odd `den`); such a
    // value could never fit in `i32` anyway, so report overflow right here.
    let magnitude = i64::try_from(num_abs / g).map_err(|_| RationalError::Overflow)?;
    // The sign lives on the numerator: negative iff exactly one input is negative.
    let signed = if (num < 0) != (den < 0) {
        -magnitude
    } else {
        magnitude
    };
    let num = i32::try_from(signed).map_err(|_| RationalError::Overflow)?;
    let den = i32::try_from(den_abs / g).map_err(|_| RationalError::Overflow)?;
    Ok((num, den))
}
/// Exact addition `a/b + c/d = (a*d + c*b) / (b*d)`.
///
/// The cross products are computed in `i64` and then reduced. The result is
/// an error (e.g. [`RationalError::Overflow`]) if the reduced value does not
/// fit in `i32`.
impl std::ops::Add for Rational {
    type Output = Result<Self, RationalError>;

    fn add(self, rhs: Self) -> Self::Output {
        let (a, b) = (i64::from(self.num), i64::from(self.den));
        let (c, d) = (i64::from(rhs.num), i64::from(rhs.den));
        reduce_i64(a * d + c * b, b * d).map(|(num, den)| Self { num, den })
    }
}
/// Exact subtraction `a/b - c/d = (a*d - c*b) / (b*d)`.
///
/// The cross products are computed in `i64` and then reduced. The result is
/// an error (e.g. [`RationalError::Overflow`]) if the reduced value does not
/// fit in `i32`.
impl std::ops::Sub for Rational {
    type Output = Result<Self, RationalError>;

    fn sub(self, rhs: Self) -> Self::Output {
        let (a, b) = (i64::from(self.num), i64::from(self.den));
        let (c, d) = (i64::from(rhs.num), i64::from(rhs.den));
        reduce_i64(a * d - c * b, b * d).map(|(num, den)| Self { num, den })
    }
}
/// Negation flips the sign of the numerator and keeps the denominator.
impl std::ops::Neg for Rational {
    type Output = Self;

    /// # Panics
    ///
    /// Panics if the numerator is `i32::MIN`, whose negation does not fit in
    /// `i32`. The check is explicit so debug and release builds behave the
    /// same. With overflow checks disabled (the default release profile), a
    /// plain `-` would silently wrap back to `i32::MIN`.
    fn neg(self) -> Self {
        let num = self
            .num
            .checked_neg()
            .expect("cannot negate a Rational whose numerator is i32::MIN");
        Self { num, den: self.den }
    }
}
/// Exact multiplication `(a/b) * (c/d) = (a*c) / (b*d)`.
///
/// The products are computed in `i64` and then reduced. The result is an
/// error (e.g. [`RationalError::Overflow`]) if the reduced value does not fit
/// in `i32`.
impl std::ops::Mul for Rational {
    type Output = Result<Self, RationalError>;

    fn mul(self, rhs: Self) -> Self::Output {
        let [n1, d1, n2, d2] = [self.num, self.den, rhs.num, rhs.den].map(i64::from);
        reduce_i64(n1 * n2, d1 * d2).map(|(num, den)| Self { num, den })
    }
}
/// Greatest common divisor of `a` and `b` (iterative Euclidean algorithm).
///
/// `gcd(x, 0) == x`, so in particular `gcd(0, 0) == 0`.
fn gcd(a: u32, b: u32) -> u32 {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        (x, y) = (y, x % y);
    }
    x
}

/// `u64` counterpart of [`gcd`], used by `reduce_i64` on `i64` intermediates.
fn gcd64(a: u64, b: u64) -> u64 {
    let (mut x, mut y) = (a, b);
    loop {
        if y == 0 {
            return x;
        }
        let rem = x % y;
        x = y;
        y = rem;
    }
}
/// Identifier of a base dimension.
///
/// `Prelude` ids are identified by name alone. `UserDefined` ids also carry
/// a `dag` identifier alongside their name. The derived `Ord` orders every
/// `Prelude` id before every `UserDefined` id.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum BaseDimId {
    Prelude(String),
    UserDefined {
        dag: crate::dag_id::DagId,
        name: String,
    },
}

impl BaseDimId {
    /// Returns the id's own `name`, whatever the variant. Used as the display
    /// symbol when no explicit name mapping is available.
    #[must_use]
    pub fn fallback_symbol(&self) -> String {
        let (Self::Prelude(name) | Self::UserDefined { name, .. }) = self;
        name.to_owned()
    }
}
/// A physical dimension: a product of base dimensions raised to rational
/// exponents.
///
/// `Dimension`'s own operations drop entries whose exponent becomes zero.
/// Equal dimensions therefore have identical maps, and the derived
/// `PartialEq`/`Hash` agree.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Dimension {
    exponents: BTreeMap<BaseDimId, Rational>,
}

/// Debug output names each base dimension in `BaseDimId` order, for example
/// `Dimension(Length * Time^-1)`. An empty dimension prints as
/// `Dimension(Dimensionless)`.
impl fmt::Debug for Dimension {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.exponents.is_empty() {
            return f.write_str("Dimension(Dimensionless)");
        }
        f.write_str("Dimension(")?;
        for (index, (id, exp)) in self.exponents.iter().enumerate() {
            if index > 0 {
                f.write_str(" * ")?;
            }
            let (BaseDimId::Prelude(name) | BaseDimId::UserDefined { name, .. }) = id;
            f.write_str(name)?;
            if *exp != Rational::ONE {
                write!(f, "^{exp}")?;
            }
        }
        f.write_str(")")
    }
}
impl Dimension {
    /// The dimension of a pure number: no base dimension is stored.
    #[must_use]
    pub const fn dimensionless() -> Self {
        Self {
            exponents: BTreeMap::new(),
        }
    }

    /// The base dimension `id` raised to the first power.
    #[must_use]
    pub fn base(id: BaseDimId) -> Self {
        let mut exponents = BTreeMap::new();
        exponents.insert(id, Rational::ONE);
        Self { exponents }
    }

    /// Returns `true` when no base dimension is stored.
    #[must_use]
    pub fn is_dimensionless(&self) -> bool {
        self.exponents.is_empty()
    }

    /// The exponent of `id`, or [`Rational::ZERO`] when `id` is absent.
    #[must_use]
    pub fn get_exponent(&self, id: &BaseDimId) -> Rational {
        self.exponents.get(id).copied().unwrap_or(Rational::ZERO)
    }

    /// Iterates over the stored `(base dimension, exponent)` pairs in
    /// ascending `BaseDimId` order.
    pub fn iter(&self) -> impl Iterator<Item = (&BaseDimId, &Rational)> {
        self.exponents.iter()
    }

    /// Raises the dimension to the rational power `exp` by multiplying every
    /// exponent by `exp`. A zero power yields [`Dimension::dimensionless`].
    ///
    /// # Errors
    ///
    /// Propagates the error from `Rational` multiplication, e.g. when a scaled
    /// exponent does not fit in `i32`.
    pub fn pow(&self, exp: Rational) -> Result<Self, RationalError> {
        if exp.is_zero() {
            return Ok(Self::dimensionless());
        }
        let mut exponents = BTreeMap::new();
        for (id, &e) in &self.exponents {
            let new_exp = (e * exp)?;
            // Keeps the "no zero exponents" invariant even if a zero entry
            // was ever stored.
            if !new_exp.is_zero() {
                exponents.insert(id.clone(), new_exp);
            }
        }
        Ok(Self { exponents })
    }

    /// Raises the dimension to the integer power `n`; same as
    /// `self.pow(Rational::from_int(n))`.
    ///
    /// # Errors
    ///
    /// Same as [`Dimension::pow`].
    pub fn pow_int(&self, n: i32) -> Result<Self, RationalError> {
        self.pow(Rational::from_int(n))
    }

    /// Returns a [`DimensionDisplay`] that formats this dimension using
    /// `names` for base-dimension symbols. Ids missing from `names` fall back
    /// to [`BaseDimId::fallback_symbol`].
    #[must_use]
    pub const fn display_with<'a>(
        &'a self,
        names: &'a BTreeMap<BaseDimId, String>,
    ) -> DimensionDisplay<'a> {
        DimensionDisplay { dim: self, names }
    }

    /// Writes the exponents as a product followed by a quotient.
    ///
    /// Positive exponents come first, in `BaseDimId` order and joined by
    /// `mul_sep`, each as `name` or `name^exp`. Negative exponents follow.
    /// If there is no positive factor, the first negative factor keeps its
    /// sign (`Time^-1`). Every later negative factor is written as
    /// `div_sep name` or `div_sep name^|exp|`.
    fn write_exponents(
        &self,
        w: &mut impl fmt::Write,
        names: &BTreeMap<BaseDimId, String>,
        mul_sep: &str,
        div_sep: &str,
    ) -> fmt::Result {
        // Writes the mapped name of `id`, falling back to its own symbol;
        // only the fallback path allocates.
        fn write_name(
            w: &mut impl fmt::Write,
            names: &BTreeMap<BaseDimId, String>,
            id: &BaseDimId,
        ) -> fmt::Result {
            match names.get(id) {
                Some(name) => w.write_str(name),
                None => w.write_str(&id.fallback_symbol()),
            }
        }

        // Writes `^|exp|` unless `|exp| == 1`, in the same `num` / `num/den`
        // form as `Rational`'s `Display`. The magnitude comes from
        // `unsigned_abs`, so an `i32::MIN` numerator cannot overflow on negation.
        fn write_magnitude(w: &mut impl fmt::Write, exp: Rational) -> fmt::Result {
            let (num, den) = (exp.num().unsigned_abs(), exp.den());
            match (num, den) {
                (1, 1) => Ok(()),
                (_, 1) => write!(w, "^{num}"),
                _ => write!(w, "^{num}/{den}"),
            }
        }

        let mut first = true;
        for (id, &exp) in self.exponents.iter().filter(|(_, e)| e.num() > 0) {
            if !first {
                w.write_str(mul_sep)?;
            }
            first = false;
            write_name(w, names, id)?;
            if exp != Rational::ONE {
                write!(w, "^{exp}")?;
            }
        }
        for (id, &exp) in self.exponents.iter().filter(|(_, e)| e.num() < 0) {
            if first {
                // No numerator to divide: show the signed exponent instead of `1 / x`.
                write_name(w, names, id)?;
                write!(w, "^{exp}")?;
                first = false;
            } else {
                w.write_str(div_sep)?;
                write_name(w, names, id)?;
                write_magnitude(w, exp)?;
            }
        }
        Ok(())
    }
}
/// Display adapter returned by `Dimension::display_with`; renders a
/// dimension using caller-supplied base-dimension names.
pub struct DimensionDisplay<'a> {
    dim: &'a Dimension,
    names: &'a BTreeMap<BaseDimId, String>,
}

impl fmt::Display for DimensionDisplay<'_> {
    /// Prints `Dimensionless` for an empty dimension. Otherwise prints the
    /// factors separated by ` * ` and ` / `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { dim, names } = *self;
        if dim.is_dimensionless() {
            f.write_str("Dimensionless")
        } else {
            dim.write_exponents(f, names, " * ", " / ")
        }
    }
}
/// Chooses whether `Dimension::combine` adds exponents (multiplying
/// dimensions) or subtracts them (dividing).
#[derive(Clone, Copy)]
enum CombineOp {
    Add,
    Sub,
}

impl Dimension {
    /// Merges `other`'s exponents into `self`'s, one base dimension at a time,
    /// using `op`. Entries whose result is zero are removed, so no zero
    /// exponent is ever stored.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Rational` addition or subtraction, e.g. when
    /// an exponent no longer fits in `i32`.
    fn combine(self, other: &Self, op: CombineOp) -> Result<Self, RationalError> {
        let Self { mut exponents } = self;
        for (id, &rhs) in &other.exponents {
            let lhs = exponents.get(id).copied().unwrap_or(Rational::ZERO);
            let merged = match op {
                CombineOp::Add => lhs + rhs,
                CombineOp::Sub => lhs - rhs,
            }?;
            if merged.is_zero() {
                exponents.remove(id);
            } else {
                exponents.insert(id.clone(), merged);
            }
        }
        Ok(Self { exponents })
    }
}
impl std::ops::Mul for Dimension {
type Output = Result<Self, RationalError>;
fn mul(self, other: Self) -> Self::Output {
self.combine(&other, CombineOp::Add)
}
}
impl std::ops::Div for Dimension {
type Output = Result<Self, RationalError>;
fn div(self, other: Self) -> Self::Output {
self.combine(&other, CombineOp::Sub)
}
}
impl std::ops::Mul for &Dimension {
type Output = Result<Dimension, RationalError>;
fn mul(self, other: Self) -> Self::Output {
self.clone().combine(other, CombineOp::Add)
}
}
impl std::ops::Div for &Dimension {
type Output = Result<Dimension, RationalError>;
fn div(self, other: Self) -> Self::Output {
self.clone().combine(other, CombineOp::Sub)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every prelude base dimension used by these tests. Shared by
    /// `test_names` and the property-test generators.
    const PRELUDE_DIMS: [&str; 8] = [
        "Length",
        "Time",
        "Mass",
        "Temperature",
        "ElectricCurrent",
        "Amount",
        "LuminousIntensity",
        "Angle",
    ];

    /// Shorthand for a rational whose denominator is known to be non-zero.
    fn r(num: i32, den: i32) -> Rational {
        Rational::try_new(num, den).expect("non-zero denominator")
    }

    fn length() -> BaseDimId {
        BaseDimId::Prelude("Length".to_string())
    }

    fn time() -> BaseDimId {
        BaseDimId::Prelude("Time".to_string())
    }

    fn mass() -> BaseDimId {
        BaseDimId::Prelude("Mass".to_string())
    }

    /// Maps every prelude base dimension to its own name.
    fn test_names() -> BTreeMap<BaseDimId, String> {
        PRELUDE_DIMS
            .iter()
            .map(|&name| (BaseDimId::Prelude(name.to_string()), name.to_string()))
            .collect()
    }

    #[test]
    fn rational_creation_and_reduction() {
        assert_eq!(r(2, 4), r(1, 2));
        assert_eq!(r(-3, 6), r(-1, 2));
        assert_eq!(r(6, -4), r(-3, 2));
        assert_eq!(r(0, 5), Rational::ZERO);
        assert_eq!(r(i32::MAX, i32::MAX), Rational::ONE);
        assert_eq!(r(i32::MIN, i32::MIN), Rational::ONE);
    }

    #[test]
    fn rational_zero_denominator_is_rejected() {
        assert_eq!(Rational::try_new(1, 0), Err(RationalError::ZeroDenominator));
        assert_eq!(Rational::try_new(0, 0), Err(RationalError::ZeroDenominator));
    }

    #[test]
    fn rational_arithmetic() {
        let half = r(1, 2);
        let third = r(1, 3);
        let sum = (half + third).unwrap();
        assert_eq!(sum, r(5, 6));
        let diff = (half - third).unwrap();
        assert_eq!(diff, r(1, 6));
        let prod = (half * third).unwrap();
        assert_eq!(prod, r(1, 6));
        assert_eq!(-half, r(-1, 2));
    }

    #[test]
    fn rational_arithmetic_overflow_is_reported() {
        let big = Rational::from_int(i32::MAX);
        assert_eq!(big * Rational::from_int(2), Err(RationalError::Overflow));
        assert_eq!(big + big, Err(RationalError::Overflow));
    }

    #[test]
    fn rational_from_int() {
        assert_eq!(Rational::from_int(3), r(3, 1));
        assert_eq!(Rational::from_int(0), Rational::ZERO);
        assert_eq!(Rational::from_int(-2), r(-2, 1));
    }

    #[test]
    fn rational_formatting() {
        assert_eq!(r(3, 1).to_string(), "3");
        assert_eq!(r(-2, 4).to_string(), "-1/2");
        assert_eq!(format!("{:?}", r(4, 6)), "2/3");
    }

    #[test]
    fn dimension_base() {
        let len = Dimension::base(length());
        assert_eq!(len.get_exponent(&length()), Rational::ONE);
        assert!(len.get_exponent(&time()).is_zero());
        assert!(len.get_exponent(&mass()).is_zero());
    }

    #[test]
    fn dimension_dimensionless() {
        assert!(Dimension::dimensionless().is_dimensionless());
        assert!(!Dimension::base(length()).is_dimensionless());
    }

    #[test]
    fn dimension_velocity() {
        let l = Dimension::base(length());
        let t = Dimension::base(time());
        let velocity = (l / t).unwrap();
        assert_eq!(velocity.get_exponent(&length()), Rational::ONE);
        assert_eq!(velocity.get_exponent(&time()), Rational::from_int(-1));
    }

    #[test]
    fn dimension_acceleration() {
        let l = Dimension::base(length());
        let t = Dimension::base(time());
        let accel = (l / t.pow_int(2).unwrap()).unwrap();
        assert_eq!(accel.get_exponent(&length()), Rational::ONE);
        assert_eq!(accel.get_exponent(&time()), Rational::from_int(-2));
    }

    #[test]
    fn dimension_force() {
        let m = Dimension::base(mass());
        let l = Dimension::base(length());
        let t = Dimension::base(time());
        let force = ((m * l).unwrap() / t.pow_int(2).unwrap()).unwrap();
        assert_eq!(force.get_exponent(&mass()), Rational::ONE);
        assert_eq!(force.get_exponent(&length()), Rational::ONE);
        assert_eq!(force.get_exponent(&time()), Rational::from_int(-2));
    }

    #[test]
    fn dimension_iter_is_ordered_by_id() {
        let force = ((Dimension::base(mass()) * Dimension::base(length())).unwrap()
            / Dimension::base(time()).pow_int(2).unwrap())
        .unwrap();
        let pairs: Vec<_> = force.iter().map(|(id, exp)| (id.clone(), *exp)).collect();
        assert_eq!(
            pairs,
            vec![
                (length(), Rational::ONE),
                (mass(), Rational::ONE),
                (time(), Rational::from_int(-2)),
            ]
        );
    }

    #[test]
    fn dimension_sqrt() {
        let area = Dimension::base(length()).pow_int(2).unwrap();
        let sqrt_area = area.pow(Rational::HALF).unwrap();
        assert_eq!(sqrt_area, Dimension::base(length()));
    }

    #[test]
    fn dimension_pow_zero_is_dimensionless() {
        let mass_length = (Dimension::base(mass()) * Dimension::base(length())).unwrap();
        assert!(mass_length.pow(Rational::ZERO).unwrap().is_dimensionless());
    }

    #[test]
    fn dimension_mul_div_inverse() {
        let l = Dimension::base(length());
        let t = Dimension::base(time());
        let velocity = (l.clone() / t.clone()).unwrap();
        assert_eq!((velocity.clone() * t.clone()).unwrap(), l);
        assert_eq!((l / velocity).unwrap(), t);
    }

    #[test]
    fn dimension_reference_operators_match_owned() {
        let l = Dimension::base(length());
        let t = Dimension::base(time());
        assert_eq!((&l * &t).unwrap(), (l.clone() * t.clone()).unwrap());
        assert_eq!((&l / &t).unwrap(), (l / t).unwrap());
    }

    #[test]
    fn dimension_dimensionless_mul() {
        let l = Dimension::base(length());
        assert_eq!((Dimension::dimensionless() * l.clone()).unwrap(), l);
        assert_eq!((l.clone() * Dimension::dimensionless()).unwrap(), l);
    }

    #[test]
    fn dimension_debug_format() {
        assert_eq!(
            format!("{:?}", Dimension::dimensionless()),
            "Dimension(Dimensionless)"
        );
        let velocity = (Dimension::base(length()) / Dimension::base(time())).unwrap();
        assert_eq!(format!("{velocity:?}"), "Dimension(Length * Time^-1)");
    }

    #[test]
    fn dimension_display_simple() {
        let names = test_names();
        assert_eq!(
            format!("{}", Dimension::dimensionless().display_with(&names)),
            "Dimensionless"
        );
        assert_eq!(
            format!("{}", Dimension::base(length()).display_with(&names)),
            "Length"
        );
    }

    #[test]
    fn dimension_display_velocity() {
        let names = test_names();
        let velocity = (Dimension::base(length()) / Dimension::base(time())).unwrap();
        assert_eq!(
            format!("{}", velocity.display_with(&names)),
            "Length / Time"
        );
    }

    #[test]
    fn dimension_display_uses_fallback_names() {
        let names = BTreeMap::new();
        let velocity = (Dimension::base(length()) / Dimension::base(time())).unwrap();
        assert_eq!(velocity.display_with(&names).to_string(), "Length / Time");
    }

    #[test]
    fn dimension_display_force() {
        let names = test_names();
        let force = ((Dimension::base(mass()) * Dimension::base(length())).unwrap()
            / Dimension::base(time()).pow_int(2).unwrap())
        .unwrap();
        assert_eq!(
            format!("{}", force.display_with(&names)),
            "Length * Mass / Time^2"
        );
    }

    #[test]
    fn dimension_display_area() {
        let names = test_names();
        let area = Dimension::base(length()).pow_int(2).unwrap();
        assert_eq!(format!("{}", area.display_with(&names)), "Length^2");
    }

    #[test]
    fn dimension_display_frequency() {
        let names = test_names();
        let freq = (Dimension::dimensionless() / Dimension::base(time())).unwrap();
        assert_eq!(format!("{}", freq.display_with(&names)), "Time^-1");
    }

    #[test]
    fn dimension_display_multiple_denominators() {
        let names = test_names();
        let per_mass_time = ((Dimension::dimensionless() / Dimension::base(mass())).unwrap()
            / Dimension::base(time()))
        .unwrap();
        assert_eq!(
            per_mass_time.display_with(&names).to_string(),
            "Mass^-1 / Time"
        );
    }

    #[test]
    fn dimension_display_fractional_exponents() {
        let names = test_names();
        let root_length = Dimension::base(length()).pow(Rational::HALF).unwrap();
        assert_eq!(root_length.display_with(&names).to_string(), "Length^1/2");
        let per_root_time = (Dimension::base(length())
            / Dimension::base(time()).pow(Rational::HALF).unwrap())
        .unwrap();
        assert_eq!(
            per_root_time.display_with(&names).to_string(),
            "Length / Time^1/2"
        );
    }

    #[test]
    fn dimension_user_defined_base() {
        let info_id = BaseDimId::UserDefined {
            dag: crate::dag_id::DagId::root("test"),
            name: "Information".to_string(),
        };
        let information = Dimension::base(info_id.clone());
        let t = Dimension::base(time());
        let bandwidth = (information / t).unwrap();
        assert_eq!(bandwidth.get_exponent(&info_id), Rational::ONE);
        assert_eq!(bandwidth.get_exponent(&time()), Rational::from_int(-1));
        let mut names = test_names();
        names.insert(info_id, "Information".to_string());
        assert_eq!(
            format!("{}", bandwidth.display_with(&names)),
            "Information / Time"
        );
    }

    #[test]
    fn dimension_hash_consistency() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let a = (Dimension::base(length()) / Dimension::base(time())).unwrap();
        let b = (Dimension::base(length()) / Dimension::base(time())).unwrap();
        assert_eq!(a, b);
        let mut ha = DefaultHasher::new();
        a.hash(&mut ha);
        let mut hb = DefaultHasher::new();
        b.hash(&mut hb);
        assert_eq!(ha.finish(), hb.finish());
    }

    mod prop {
        use super::*;
        use proptest::prelude::*;

        /// Rationals with numerator and denominator in `-50..=50` (denominator non-zero).
        fn arb_rational() -> impl Strategy<Value = Rational> {
            (-50i32..=50, -50i32..=50)
                .prop_filter("denominator must be non-zero", |&(_, d)| d != 0)
                .prop_map(|(n, d)| Rational::try_new(n, d).expect("filtered d != 0"))
        }

        /// Dimensions over the prelude base dimensions, with zero exponents
        /// filtered out to match the invariant upheld by `Dimension`'s operations.
        fn arb_dimension() -> impl Strategy<Value = Dimension> {
            proptest::collection::btree_map(0usize..8, arb_rational(), 0..=8).prop_map(|map| {
                let exponents = map
                    .into_iter()
                    .filter(|(_, r)| !r.is_zero())
                    .map(|(idx, r)| (BaseDimId::Prelude(PRELUDE_DIMS[idx].to_string()), r))
                    .collect();
                Dimension { exponents }
            })
        }

        proptest! {
            #[test]
            fn rational_always_reduced(n in -100i32..=100, d in -100i32..=100) {
                prop_assume!(d != 0);
                let r = Rational::try_new(n, d).expect("d != 0 by prop_assume");
                prop_assert!(r.den() > 0, "den must be positive, got {}", r.den());
                if r.num() != 0 {
                    let g = gcd(r.num().unsigned_abs(), r.den().unsigned_abs());
                    prop_assert_eq!(g, 1, "not reduced: {}/{}", r.num(), r.den());
                } else {
                    prop_assert_eq!(r.den(), 1, "zero should have den=1, got {}", r.den());
                }
            }

            #[test]
            fn rational_add_commutative(a in arb_rational(), b in arb_rational()) {
                prop_assert_eq!((a + b).unwrap(), (b + a).unwrap());
            }

            #[test]
            fn rational_mul_commutative(a in arb_rational(), b in arb_rational()) {
                prop_assert_eq!((a * b).unwrap(), (b * a).unwrap());
            }

            #[test]
            fn rational_additive_identity(a in arb_rational()) {
                prop_assert_eq!((a + Rational::ZERO).unwrap(), a);
            }

            #[test]
            fn rational_multiplicative_identity(a in arb_rational()) {
                prop_assert_eq!((a * Rational::ONE).unwrap(), a);
            }

            #[test]
            fn rational_additive_inverse(a in arb_rational()) {
                prop_assert_eq!((a + (-a)).unwrap(), Rational::ZERO);
            }

            #[test]
            fn rational_sub_self_is_zero(a in arb_rational()) {
                prop_assert_eq!((a - a).unwrap(), Rational::ZERO);
            }

            #[test]
            fn dimension_mul_commutative(a in arb_dimension(), b in arb_dimension()) {
                prop_assert_eq!((a.clone() * b.clone()).unwrap(), (b * a).unwrap());
            }

            #[test]
            fn dimension_dimensionless_is_mul_identity(a in arb_dimension()) {
                prop_assert_eq!((a.clone() * Dimension::dimensionless()).unwrap(), a);
            }

            #[test]
            fn dimension_self_div_is_dimensionless(a in arb_dimension()) {
                prop_assert_eq!((a.clone() / a).unwrap(), Dimension::dimensionless());
            }

            #[test]
            fn dimension_div_inverse(a in arb_dimension(), b in arb_dimension()) {
                prop_assert_eq!(((a.clone() / b.clone()).unwrap() * b).unwrap(), a);
            }

            #[test]
            fn dimension_pow_int_consistent_with_pow(a in arb_dimension(), n in -3i32..=3) {
                prop_assert_eq!(a.pow_int(n).unwrap(), a.pow(Rational::from_int(n)).unwrap());
            }

            #[test]
            fn dimension_pow_distributes_over_mul(
                a in arb_dimension(),
                b in arb_dimension(),
                r in arb_rational(),
            ) {
                prop_assert_eq!(
                    (a.clone() * b.clone()).unwrap().pow(r).unwrap(),
                    (a.pow(r).unwrap() * b.pow(r).unwrap()).unwrap(),
                );
            }
        }
    }
}