1use std::collections::BTreeMap;
2use std::fmt;
3
4use thiserror::Error;
5
/// An exact rational number `num / den`, used for dimension exponents.
///
/// Values produced by [`Rational::try_new`], the associated constants and the
/// arithmetic operators are kept in canonical form: the denominator is
/// strictly positive, numerator and denominator share no common factor, and
/// zero is stored as `0/1`. Because the form is canonical, the derived
/// equality and hashing compare mathematical values, not representations.
#[derive(Hash, Eq, PartialEq, Copy, Clone)]
pub struct Rational {
    /// Numerator; carries the sign of the value.
    num: i32,
    /// Denominator; strictly positive in canonical form.
    den: i32,
}
14
15impl fmt::Debug for Rational {
16 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
17 fmt::Display::fmt(self, f)
18 }
19}
20
21impl fmt::Display for Rational {
22 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
23 if self.den == 1 {
24 write!(f, "{}", self.num)
25 } else {
26 write!(f, "{}/{}", self.num, self.den)
27 }
28 }
29}
30
31impl Rational {
32 pub const ZERO: Self = Self { num: 0, den: 1 };
33 pub const ONE: Self = Self { num: 1, den: 1 };
34 pub const HALF: Self = Self { num: 1, den: 2 };
36 pub const THIRD: Self = Self { num: 1, den: 3 };
38
39 pub fn try_new(num: i32, den: i32) -> Result<Self, RationalError> {
43 if den == 0 {
44 return Err(RationalError::ZeroDenominator);
45 }
46 if num == 0 {
47 return Ok(Self::ZERO);
48 }
49 let g = gcd(num.unsigned_abs(), den.unsigned_abs()).cast_signed();
50 let (n, d) = (num / g, den / g);
51 if d < 0 {
53 Ok(Self { num: -n, den: -d })
54 } else {
55 Ok(Self { num: n, den: d })
56 }
57 }
58
59 #[must_use]
61 pub const fn from_int(n: i32) -> Self {
62 if n == 0 {
63 Self::ZERO
64 } else {
65 Self { num: n, den: 1 }
66 }
67 }
68
69 #[must_use]
71 pub const fn num(self) -> i32 {
72 self.num
73 }
74
75 #[must_use]
77 pub const fn den(self) -> i32 {
78 self.den
79 }
80
81 #[must_use]
82 pub const fn is_zero(self) -> bool {
83 self.num == 0
84 }
85
86 #[must_use]
87 pub const fn is_integer(self) -> bool {
88 self.den == 1
89 }
90}
91
92#[derive(Debug, Clone, PartialEq, Eq, Error)]
94pub enum RationalError {
95 #[error("denominator must not be zero")]
97 ZeroDenominator,
98 #[error("dimension exponent overflowed i32")]
103 Overflow,
104}
105
106fn reduce_i64(num: i64, den: i64) -> Result<(i32, i32), RationalError> {
111 if den == 0 {
112 return Err(RationalError::ZeroDenominator);
113 }
114 if num == 0 {
115 return Ok((0, 1));
116 }
117 let g = gcd64(num.unsigned_abs(), den.unsigned_abs()).cast_signed();
118 let (mut n, mut d) = (num / g, den / g);
119 if d < 0 {
120 n = -n;
121 d = -d;
122 }
123 let num = i32::try_from(n).map_err(|_| RationalError::Overflow)?;
124 let den = i32::try_from(d).map_err(|_| RationalError::Overflow)?;
125 Ok((num, den))
126}
127
128impl std::ops::Add for Rational {
129 type Output = Result<Self, RationalError>;
130 fn add(self, rhs: Self) -> Self::Output {
131 let num =
133 i64::from(self.num) * i64::from(rhs.den) + i64::from(rhs.num) * i64::from(self.den);
134 let den = i64::from(self.den) * i64::from(rhs.den);
135 let (n, d) = reduce_i64(num, den)?;
136 Ok(Self { num: n, den: d })
137 }
138}
139
140impl std::ops::Sub for Rational {
141 type Output = Result<Self, RationalError>;
142 fn sub(self, rhs: Self) -> Self::Output {
143 let num =
144 i64::from(self.num) * i64::from(rhs.den) - i64::from(rhs.num) * i64::from(self.den);
145 let den = i64::from(self.den) * i64::from(rhs.den);
146 let (n, d) = reduce_i64(num, den)?;
147 Ok(Self { num: n, den: d })
148 }
149}
150
151impl std::ops::Neg for Rational {
152 type Output = Self;
153 fn neg(self) -> Self {
154 Self {
155 num: -self.num,
156 den: self.den,
157 }
158 }
159}
160
161impl std::ops::Mul for Rational {
162 type Output = Result<Self, RationalError>;
163 fn mul(self, rhs: Self) -> Self::Output {
164 let num = i64::from(self.num) * i64::from(rhs.num);
165 let den = i64::from(self.den) * i64::from(rhs.den);
166 let (n, d) = reduce_i64(num, den)?;
167 Ok(Self { num: n, den: d })
168 }
169}
170
/// Greatest common divisor via Euclid's algorithm; `gcd(x, 0) == x` and
/// `gcd(0, 0) == 0`.
fn gcd(a: u32, b: u32) -> u32 {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        (x, y) = (y, x % y);
    }
    x
}
174
/// 64-bit greatest common divisor via Euclid's algorithm; `gcd64(x, 0) == x`
/// and `gcd64(0, 0) == 0`.
fn gcd64(a: u64, b: u64) -> u64 {
    let (mut x, mut y) = (a, b);
    while y != 0 {
        (x, y) = (y, x % y);
    }
    x
}
178
179#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
184pub enum BaseDimId {
185 Prelude(String),
187 UserDefined {
189 dag: crate::dag_id::DagId,
190 name: String,
191 },
192}
193
194impl BaseDimId {
195 #[must_use]
197 pub fn fallback_symbol(&self) -> String {
198 match self {
199 Self::Prelude(name) | Self::UserDefined { name, .. } => name.clone(),
200 }
201 }
202}
203
204#[derive(Clone, PartialEq, Eq, Hash)]
213pub struct Dimension {
214 exponents: BTreeMap<BaseDimId, Rational>,
216}
217
218impl fmt::Debug for Dimension {
219 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
220 if self.is_dimensionless() {
221 write!(f, "Dimension(Dimensionless)")
222 } else {
223 write!(f, "Dimension(")?;
224 let mut first = true;
225 for (id, exp) in &self.exponents {
226 if !first {
227 write!(f, " * ")?;
228 }
229 first = false;
230 match id {
231 BaseDimId::Prelude(name) | BaseDimId::UserDefined { name, .. } => {
232 write!(f, "{name}")?;
233 }
234 }
235 if *exp != Rational::ONE {
236 write!(f, "^{exp}")?;
237 }
238 }
239 write!(f, ")")
240 }
241 }
242}
243
244impl Dimension {
245 #[must_use]
247 pub const fn dimensionless() -> Self {
248 Self {
249 exponents: BTreeMap::new(),
250 }
251 }
252
253 #[must_use]
255 pub fn base(id: BaseDimId) -> Self {
256 let mut exponents = BTreeMap::new();
257 exponents.insert(id, Rational::ONE);
258 Self { exponents }
259 }
260
261 #[must_use]
262 pub fn is_dimensionless(&self) -> bool {
263 self.exponents.is_empty()
264 }
265
266 #[must_use]
268 pub fn get_exponent(&self, id: &BaseDimId) -> Rational {
269 self.exponents.get(id).copied().unwrap_or(Rational::ZERO)
270 }
271
272 pub fn iter(&self) -> impl Iterator<Item = (&BaseDimId, &Rational)> {
274 self.exponents.iter()
275 }
276
277 pub fn pow(&self, exp: Rational) -> Result<Self, RationalError> {
282 if exp.is_zero() {
283 return Ok(Self::dimensionless());
284 }
285 let mut exponents = BTreeMap::new();
286 for (id, &e) in &self.exponents {
287 let new_exp = (e * exp)?;
288 if !new_exp.is_zero() {
289 exponents.insert(id.clone(), new_exp);
290 }
291 }
292 Ok(Self { exponents })
293 }
294
295 pub fn pow_int(&self, n: i32) -> Result<Self, RationalError> {
300 self.pow(Rational::from_int(n))
301 }
302
303 #[must_use]
308 pub const fn display_with<'a>(
309 &'a self,
310 names: &'a BTreeMap<BaseDimId, String>,
311 ) -> DimensionDisplay<'a> {
312 DimensionDisplay { dim: self, names }
313 }
314
315 fn write_exponents(
321 &self,
322 w: &mut impl fmt::Write,
323 names: &BTreeMap<BaseDimId, String>,
324 mul_sep: &str,
325 div_sep: &str,
326 ) -> fmt::Result {
327 let mut first = true;
328
329 for (id, &exp) in &self.exponents {
331 if exp.num() <= 0 {
332 continue;
333 }
334 if !first {
335 w.write_str(mul_sep)?;
336 }
337 first = false;
338 let name = names
339 .get(id)
340 .map_or_else(|| id.fallback_symbol(), String::clone);
341 write!(w, "{name}")?;
342 if exp != Rational::ONE {
343 write!(w, "^{exp}")?;
344 }
345 }
346
347 for (id, &exp) in &self.exponents {
349 if exp.num() >= 0 {
350 continue;
351 }
352 let name = names
353 .get(id)
354 .map_or_else(|| id.fallback_symbol(), String::clone);
355 if first {
356 write!(w, "{name}^{exp}")?;
358 first = false;
359 } else {
360 w.write_str(div_sep)?;
361 write!(w, "{name}")?;
362 let pos_exp = -exp;
363 if pos_exp != Rational::ONE {
364 write!(w, "^{pos_exp}")?;
365 }
366 }
367 }
368
369 Ok(())
370 }
371}
372
373pub struct DimensionDisplay<'a> {
375 dim: &'a Dimension,
376 names: &'a BTreeMap<BaseDimId, String>,
377}
378
379impl fmt::Display for DimensionDisplay<'_> {
380 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
381 if self.dim.is_dimensionless() {
382 return write!(f, "Dimensionless");
383 }
384 self.dim.write_exponents(f, self.names, " * ", " / ")
385 }
386}
387
/// The per-base-dimension exponent operation applied by `Dimension::combine`.
#[derive(Copy, Clone)]
enum CombineOp {
    /// Exponents are summed (dimension multiplication).
    Add,
    /// The right-hand exponent is subtracted (dimension division).
    Sub,
}
396
397impl Dimension {
398 fn combine(self, other: &Self, op: CombineOp) -> Result<Self, RationalError> {
400 let mut exponents = self.exponents;
401 for (id, exp) in &other.exponents {
402 let entry = exponents.entry(id.clone()).or_insert(Rational::ZERO);
403 *entry = match op {
404 CombineOp::Add => (*entry + *exp)?,
405 CombineOp::Sub => (*entry - *exp)?,
406 };
407 if entry.is_zero() {
408 exponents.remove(id);
409 }
410 }
411 Ok(Self { exponents })
412 }
413}
414
415impl std::ops::Mul for Dimension {
416 type Output = Result<Self, RationalError>;
417 fn mul(self, other: Self) -> Self::Output {
419 self.combine(&other, CombineOp::Add)
420 }
421}
422
423impl std::ops::Div for Dimension {
424 type Output = Result<Self, RationalError>;
425 fn div(self, other: Self) -> Self::Output {
427 self.combine(&other, CombineOp::Sub)
428 }
429}
430
431impl std::ops::Mul for &Dimension {
432 type Output = Result<Dimension, RationalError>;
433 fn mul(self, other: Self) -> Self::Output {
434 self.clone().combine(other, CombineOp::Add)
435 }
436}
437
438impl std::ops::Div for &Dimension {
439 type Output = Result<Dimension, RationalError>;
440 fn div(self, other: Self) -> Self::Output {
441 self.clone().combine(other, CombineOp::Sub)
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Test shorthand: builds a canonical `Rational`, panicking on a zero denominator.
    fn r(num: i32, den: i32) -> Rational {
        Rational::try_new(num, den).expect("non-zero denominator")
    }

    // Shorthand constructors for the prelude base dimensions used below.
    fn length() -> BaseDimId {
        BaseDimId::Prelude("Length".to_string())
    }
    fn time() -> BaseDimId {
        BaseDimId::Prelude("Time".to_string())
    }
    fn mass() -> BaseDimId {
        BaseDimId::Prelude("Mass".to_string())
    }

    /// Display-name table that maps each prelude base dimension to its own
    /// name, so `display_with` output is predictable in the assertions below.
    fn test_names() -> BTreeMap<BaseDimId, String> {
        let mut m = BTreeMap::new();
        m.insert(
            BaseDimId::Prelude("Length".to_string()),
            "Length".to_string(),
        );
        m.insert(BaseDimId::Prelude("Time".to_string()), "Time".to_string());
        m.insert(BaseDimId::Prelude("Mass".to_string()), "Mass".to_string());
        m.insert(
            BaseDimId::Prelude("Temperature".to_string()),
            "Temperature".to_string(),
        );
        m.insert(
            BaseDimId::Prelude("ElectricCurrent".to_string()),
            "ElectricCurrent".to_string(),
        );
        m.insert(
            BaseDimId::Prelude("Amount".to_string()),
            "Amount".to_string(),
        );
        m.insert(
            BaseDimId::Prelude("LuminousIntensity".to_string()),
            "LuminousIntensity".to_string(),
        );
        m.insert(BaseDimId::Prelude("Angle".to_string()), "Angle".to_string());
        m
    }

    // try_new reduces to lowest terms, moves the sign onto the numerator and
    // canonicalises zero.
    #[test]
    fn rational_creation_and_reduction() {
        assert_eq!(r(2, 4), r(1, 2));
        assert_eq!(r(-3, 6), r(-1, 2));
        assert_eq!(r(6, -4), r(-3, 2));
        assert_eq!(r(0, 5), Rational::ZERO);
    }

    // The four operators on small fractions produce reduced results.
    #[test]
    fn rational_arithmetic() {
        let half = r(1, 2);
        let third = r(1, 3);

        let sum = (half + third).unwrap();
        assert_eq!(sum, r(5, 6));

        let diff = (half - third).unwrap();
        assert_eq!(diff, r(1, 6));

        let prod = (half * third).unwrap();
        assert_eq!(prod, r(1, 6));

        assert_eq!(-half, r(-1, 2));
    }

    #[test]
    fn rational_from_int() {
        assert_eq!(Rational::from_int(3), r(3, 1));
        assert_eq!(Rational::from_int(0), Rational::ZERO);
        assert_eq!(Rational::from_int(-2), r(-2, 1));
    }

    // A base dimension has exponent 1 for itself and 0 for everything else.
    #[test]
    fn dimension_base() {
        let len = Dimension::base(length());
        assert_eq!(len.get_exponent(&length()), Rational::ONE);
        assert!(len.get_exponent(&time()).is_zero());
        assert!(len.get_exponent(&mass()).is_zero());
    }

    #[test]
    fn dimension_dimensionless() {
        assert!(Dimension::dimensionless().is_dimensionless());
        assert!(!Dimension::base(length()).is_dimensionless());
    }

    // Length / Time.
    #[test]
    fn dimension_velocity() {
        let l = Dimension::base(length());
        let t = Dimension::base(time());
        let velocity = (l / t).unwrap();

        assert_eq!(velocity.get_exponent(&length()), Rational::ONE);
        assert_eq!(velocity.get_exponent(&time()), Rational::from_int(-1));
    }

    // Length / Time^2.
    #[test]
    fn dimension_acceleration() {
        let l = Dimension::base(length());
        let t = Dimension::base(time());
        let accel = (l / t.pow_int(2).unwrap()).unwrap();

        assert_eq!(accel.get_exponent(&length()), Rational::ONE);
        assert_eq!(accel.get_exponent(&time()), Rational::from_int(-2));
    }

    // Mass * Length / Time^2.
    #[test]
    fn dimension_force() {
        let m = Dimension::base(mass());
        let l = Dimension::base(length());
        let t = Dimension::base(time());
        let force = ((m * l).unwrap() / t.pow_int(2).unwrap()).unwrap();

        assert_eq!(force.get_exponent(&mass()), Rational::ONE);
        assert_eq!(force.get_exponent(&length()), Rational::ONE);
        assert_eq!(force.get_exponent(&time()), Rational::from_int(-2));
    }

    // Rational powers: (Length^2)^(1/2) == Length.
    #[test]
    fn dimension_sqrt() {
        let area = Dimension::base(length()).pow_int(2).unwrap();
        let sqrt_area = area.pow(Rational::HALF).unwrap();
        assert_eq!(sqrt_area, Dimension::base(length()));
    }

    // Multiplication and division undo each other, and zero exponents vanish
    // so equality holds structurally.
    #[test]
    fn dimension_mul_div_inverse() {
        let l = Dimension::base(length());
        let t = Dimension::base(time());
        let velocity = (l.clone() / t.clone()).unwrap();

        assert_eq!((velocity.clone() * t.clone()).unwrap(), l);

        assert_eq!((l / velocity).unwrap(), t);
    }

    #[test]
    fn dimension_dimensionless_mul() {
        let l = Dimension::base(length());
        assert_eq!((Dimension::dimensionless() * l.clone()).unwrap(), l);
        assert_eq!((l.clone() * Dimension::dimensionless()).unwrap(), l);
    }

    #[test]
    fn dimension_display_simple() {
        let names = test_names();
        assert_eq!(
            format!("{}", Dimension::dimensionless().display_with(&names)),
            "Dimensionless"
        );
        assert_eq!(
            format!("{}", Dimension::base(length()).display_with(&names)),
            "Length"
        );
    }

    #[test]
    fn dimension_display_velocity() {
        let names = test_names();
        let velocity = (Dimension::base(length()) / Dimension::base(time())).unwrap();
        assert_eq!(
            format!("{}", velocity.display_with(&names)),
            "Length / Time"
        );
    }

    // Positive factors come first in BaseDimId order (Length < Mass), then
    // negative ones after " / " with their magnitude.
    #[test]
    fn dimension_display_force() {
        let names = test_names();
        let force = ((Dimension::base(mass()) * Dimension::base(length())).unwrap()
            / Dimension::base(time()).pow_int(2).unwrap())
        .unwrap();
        assert_eq!(
            format!("{}", force.display_with(&names)),
            "Length * Mass / Time^2"
        );
    }

    #[test]
    fn dimension_display_area() {
        let names = test_names();
        let area = Dimension::base(length()).pow_int(2).unwrap();
        assert_eq!(format!("{}", area.display_with(&names)), "Length^2");
    }

    // With no positive factor, the leading negative factor keeps its signed exponent.
    #[test]
    fn dimension_display_frequency() {
        let names = test_names();
        let freq = (Dimension::dimensionless() / Dimension::base(time())).unwrap();
        assert_eq!(format!("{}", freq.display_with(&names)), "Time^-1");
    }

    // User-defined base dimensions combine with prelude ones and render via
    // the supplied name table.
    #[test]
    fn dimension_user_defined_base() {
        let info_id = BaseDimId::UserDefined {
            dag: crate::dag_id::DagId::root("test"),
            name: "Information".to_string(),
        };
        let information = Dimension::base(info_id.clone());
        let t = Dimension::base(time());
        let bandwidth = (information / t).unwrap();

        assert_eq!(bandwidth.get_exponent(&info_id), Rational::ONE);
        assert_eq!(bandwidth.get_exponent(&time()), Rational::from_int(-1));

        let mut names = test_names();
        names.insert(info_id, "Information".to_string());
        assert_eq!(
            format!("{}", bandwidth.display_with(&names)),
            "Information / Time"
        );
    }

    // Equal dimensions built independently must hash identically.
    #[test]
    fn dimension_hash_consistency() {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let a = (Dimension::base(length()) / Dimension::base(time())).unwrap();
        let b = (Dimension::base(length()) / Dimension::base(time())).unwrap();
        assert_eq!(a, b);

        let mut ha = DefaultHasher::new();
        a.hash(&mut ha);
        let mut hb = DefaultHasher::new();
        b.hash(&mut hb);
        assert_eq!(ha.finish(), hb.finish());
    }

    /// Property-based tests (proptest) for the algebraic laws of `Rational`
    /// and `Dimension`.
    mod prop {
        use super::*;
        use proptest::prelude::*;

        /// Arbitrary canonical rationals with numerator and denominator in
        /// `-50..=50`. Small ranges keep every derived result far from i32 overflow.
        fn arb_rational() -> impl Strategy<Value = Rational> {
            (-50i32..=50, -50i32..=50)
                .prop_filter("denominator must be non-zero", |&(_, d)| d != 0)
                .prop_map(|(n, d)| Rational::try_new(n, d).expect("filtered d != 0"))
        }

        // Pool of prelude dimension names that arbitrary dimensions draw from.
        const PRELUDE_DIMS: [&str; 8] = [
            "Length",
            "Time",
            "Mass",
            "Temperature",
            "ElectricCurrent",
            "Amount",
            "LuminousIntensity",
            "Angle",
        ];

        /// Arbitrary dimensions over the prelude pool. Zero exponents are
        /// filtered out so generated values respect `Dimension`'s
        /// no-zero-exponent invariant.
        fn arb_dimension() -> impl Strategy<Value = Dimension> {
            proptest::collection::btree_map(0usize..8, arb_rational(), 0..=8).prop_map(|map| {
                let exponents = map
                    .into_iter()
                    .filter(|(_, r)| !r.is_zero())
                    .map(|(idx, r)| (BaseDimId::Prelude(PRELUDE_DIMS[idx].to_string()), r))
                    .collect();
                Dimension { exponents }
            })
        }

        proptest! {
            // Canonical form: positive denominator, lowest terms, zero as 0/1.
            #[test]
            fn rational_always_reduced(n in -100i32..=100, d in -100i32..=100) {
                prop_assume!(d != 0);
                let r = Rational::try_new(n, d).expect("d != 0 by prop_assume");
                prop_assert!(r.den() > 0, "den must be positive, got {}", r.den());
                if r.num() != 0 {
                    let g = gcd(r.num().unsigned_abs(), r.den().unsigned_abs());
                    prop_assert_eq!(g, 1, "not reduced: {}/{}", r.num(), r.den());
                } else {
                    prop_assert_eq!(r.den(), 1, "zero should have den=1, got {}", r.den());
                }
            }

            #[test]
            fn rational_add_commutative(a in arb_rational(), b in arb_rational()) {
                prop_assert_eq!((a + b).unwrap(), (b + a).unwrap());
            }

            #[test]
            fn rational_mul_commutative(a in arb_rational(), b in arb_rational()) {
                prop_assert_eq!((a * b).unwrap(), (b * a).unwrap());
            }

            #[test]
            fn rational_additive_identity(a in arb_rational()) {
                prop_assert_eq!((a + Rational::ZERO).unwrap(), a);
            }

            #[test]
            fn rational_multiplicative_identity(a in arb_rational()) {
                prop_assert_eq!((a * Rational::ONE).unwrap(), a);
            }

            #[test]
            fn rational_additive_inverse(a in arb_rational()) {
                prop_assert_eq!((a + (-a)).unwrap(), Rational::ZERO);
            }

            #[test]
            fn rational_sub_self_is_zero(a in arb_rational()) {
                prop_assert_eq!((a - a).unwrap(), Rational::ZERO);
            }

            // Dimension multiplication is commutative (exponent addition is).
            #[test]
            fn dimension_mul_commutative(a in arb_dimension(), b in arb_dimension()) {
                prop_assert_eq!((a.clone() * b.clone()).unwrap(), (b * a).unwrap());
            }

            #[test]
            fn dimension_dimensionless_is_mul_identity(a in arb_dimension()) {
                prop_assert_eq!((a.clone() * Dimension::dimensionless()).unwrap(), a);
            }

            // Relies on zero exponents being removed so the result is the empty map.
            #[test]
            fn dimension_self_div_is_dimensionless(a in arb_dimension()) {
                prop_assert_eq!((a.clone() / a).unwrap(), Dimension::dimensionless());
            }

            #[test]
            fn dimension_div_inverse(a in arb_dimension(), b in arb_dimension()) {
                prop_assert_eq!(((a.clone() / b.clone()).unwrap() * b).unwrap(), a);
            }

            #[test]
            fn dimension_pow_int_consistent_with_pow(a in arb_dimension(), n in -3i32..=3) {
                prop_assert_eq!(a.pow_int(n).unwrap(), a.pow(Rational::from_int(n)).unwrap());
            }

            // (a * b)^r == a^r * b^r for rational r.
            #[test]
            fn dimension_pow_distributes_over_mul(
                a in arb_dimension(),
                b in arb_dimension(),
                r in arb_rational(),
            ) {
                prop_assert_eq!(
                    (a.clone() * b.clone()).unwrap().pow(r).unwrap(),
                    (a.pow(r).unwrap() * b.pow(r).unwrap()).unwrap(),
                );
            }
        }
    }
}