use std::fmt;
/// Physical dimension of a quantity, stored as integer exponents of the seven SI base units.
///
/// Slot order is `[m, kg, s, A, K, mol, cd]`; for example `[1, 0, -1, 0, 0, 0, 0]` is
/// metres per second. The all-zero vector — also what [`Default`] produces — denotes a
/// dimensionless quantity.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Units(pub [i8; 7]);
impl Units {
    /// All exponents zero: a pure number with no physical dimension.
    pub const DIMENSIONLESS: Self = Self([0, 0, 0, 0, 0, 0, 0]);
    /// SI base unit of length (slot 0).
    pub const METER: Self = Self([1, 0, 0, 0, 0, 0, 0]);
    /// SI base unit of mass (slot 1).
    pub const KILOGRAM: Self = Self([0, 1, 0, 0, 0, 0, 0]);
    /// SI base unit of time (slot 2).
    pub const SECOND: Self = Self([0, 0, 1, 0, 0, 0, 0]);
    /// SI base unit of electric current (slot 3).
    pub const AMPERE: Self = Self([0, 0, 0, 1, 0, 0, 0]);
    /// SI base unit of thermodynamic temperature (slot 4).
    pub const KELVIN: Self = Self([0, 0, 0, 0, 1, 0, 0]);
    /// SI base unit of amount of substance (slot 5).
    pub const MOL: Self = Self([0, 0, 0, 0, 0, 1, 0]);
    /// SI base unit of luminous intensity (slot 6).
    pub const CANDELA: Self = Self([0, 0, 0, 0, 0, 0, 1]);
    /// Force: m·kg·s⁻².
    pub const NEWTON: Self = Self([1, 1, -2, 0, 0, 0, 0]);
    /// Energy: m²·kg·s⁻² (newton × metre).
    pub const JOULE: Self = Self([2, 1, -2, 0, 0, 0, 0]);
    /// Power: m²·kg·s⁻³ (joule per second).
    pub const WATT: Self = Self([2, 1, -3, 0, 0, 0, 0]);
    /// Pressure: m⁻¹·kg·s⁻² (newton per square metre).
    pub const PASCAL: Self = Self([-1, 1, -2, 0, 0, 0, 0]);
    /// Frequency: s⁻¹.
    pub const HERTZ: Self = Self([0, 0, -1, 0, 0, 0, 0]);
    /// Electric charge: s·A.
    pub const COULOMB: Self = Self([0, 0, 1, 1, 0, 0, 0]);
    /// Electric potential: m²·kg·s⁻³·A⁻¹ (watt per ampere).
    pub const VOLT: Self = Self([2, 1, -3, -1, 0, 0, 0]);
    /// Electrical resistance: m²·kg·s⁻³·A⁻² (volt per ampere).
    pub const OHM: Self = Self([2, 1, -3, -2, 0, 0, 0]);

    /// Wraps raw exponents given in `[m, kg, s, A, K, mol, cd]` order.
    #[inline]
    #[must_use]
    pub const fn new(exps: [i8; 7]) -> Self {
        Self(exps)
    }

    /// Returns `true` when every exponent is zero.
    #[inline]
    #[must_use]
    pub fn is_dimensionless(&self) -> bool {
        self.0 == [0; 7]
    }

    /// Units of a product: exponents are added dimension by dimension.
    ///
    /// A sum that leaves the `i8` range saturates at `i8::MIN`/`i8::MAX`, which silently
    /// yields a wrong unit; use [`Units::checked_mul`] when that must be detected.
    #[inline]
    #[must_use]
    pub fn mul(&self, other: &Self) -> Self {
        let mut result = [0i8; 7];
        for (dst, (&a, &b)) in result.iter_mut().zip(self.0.iter().zip(other.0.iter())) {
            *dst = a.saturating_add(b);
        }
        Self(result)
    }

    /// Like [`Units::mul`], but returns `None` instead of saturating when any summed
    /// exponent does not fit in an `i8`.
    #[inline]
    #[must_use]
    pub fn checked_mul(&self, other: &Self) -> Option<Self> {
        let mut result = [0i8; 7];
        for (dst, (&a, &b)) in result.iter_mut().zip(self.0.iter().zip(other.0.iter())) {
            *dst = a.checked_add(b)?;
        }
        Some(Self(result))
    }

    /// Units of a quotient: `other`'s exponents are subtracted dimension by dimension.
    ///
    /// A difference that leaves the `i8` range saturates at `i8::MIN`/`i8::MAX`, which
    /// silently yields a wrong unit; use [`Units::checked_div`] when that must be detected.
    #[inline]
    #[must_use]
    pub fn div(&self, other: &Self) -> Self {
        let mut result = [0i8; 7];
        for (dst, (&a, &b)) in result.iter_mut().zip(self.0.iter().zip(other.0.iter())) {
            *dst = a.saturating_sub(b);
        }
        Self(result)
    }

    /// Like [`Units::div`], but returns `None` instead of saturating when any exponent
    /// difference does not fit in an `i8`.
    #[inline]
    #[must_use]
    pub fn checked_div(&self, other: &Self) -> Option<Self> {
        let mut result = [0i8; 7];
        for (dst, (&a, &b)) in result.iter_mut().zip(self.0.iter().zip(other.0.iter())) {
            *dst = a.checked_sub(b)?;
        }
        Some(Self(result))
    }

    /// Raises the unit to the integer power `n` by multiplying every exponent by `n`.
    ///
    /// # Errors
    ///
    /// Returns [`UnitError::ExponentOverflow`] describing the first (lowest-index)
    /// dimension whose scaled exponent does not fit in an `i8`.
    pub fn pow_int(&self, n: i32) -> Result<Self, UnitError> {
        let representable = i64::from(i8::MIN)..=i64::from(i8::MAX);
        let mut result = [0i8; 7];
        for (dimension, (dst, &base_exp)) in result.iter_mut().zip(self.0.iter()).enumerate() {
            // i8 × i32 always fits in i64, so the product itself cannot overflow.
            let scaled = i64::from(base_exp) * i64::from(n);
            if !representable.contains(&scaled) {
                return Err(UnitError::ExponentOverflow {
                    dimension,
                    base_exp,
                    power: n,
                });
            }
            // Lossless: `scaled` was just checked to lie within the i8 range.
            *dst = scaled as i8;
        }
        Ok(Self(result))
    }

    /// Short SI symbol for the base dimension in slot `idx`, or `"?"` past the last slot.
    fn dim_symbol(idx: usize) -> &'static str {
        const SYMBOLS: [&str; 7] = ["m", "kg", "s", "A", "K", "mol", "cd"];
        SYMBOLS.get(idx).copied().unwrap_or("?")
    }
}
impl fmt::Display for Units {
    /// Writes the unit as `·`-separated SI base symbols carrying Unicode superscript
    /// exponents (e.g. `m·s⁻¹`). Zero exponents are skipped, a first power gets no
    /// superscript, and a dimensionless unit is written as `1`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Superscript glyph for each decimal digit, indexed by the digit's value.
        const SUPERSCRIPT_DIGITS: [&str; 10] = [
            "\u{2070}", "\u{00B9}", "\u{00B2}", "\u{00B3}", "\u{2074}", "\u{2075}", "\u{2076}",
            "\u{2077}", "\u{2078}", "\u{2079}",
        ];

        if self.is_dimensionless() {
            return f.write_str("1");
        }

        let factors = self.0.iter().enumerate().filter(|&(_, &exp)| exp != 0);
        for (nth, (dim, &exp)) in factors.enumerate() {
            if nth > 0 {
                f.write_str("\u{00B7}")?;
            }
            f.write_str(Self::dim_symbol(dim))?;

            // A first power is implied, so it is written with no superscript at all.
            if exp == 1 {
                continue;
            }
            if exp < 0 {
                f.write_str("\u{207B}")?;
            }
            // `unsigned_abs` keeps i8::MIN representable (as 128) without overflow.
            for ascii_digit in exp.unsigned_abs().to_string().bytes() {
                f.write_str(SUPERSCRIPT_DIGITS[usize::from(ascii_digit - b'0')])?;
            }
        }
        Ok(())
    }
}
/// Errors raised while checking or combining the physical units of quantities.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum UnitError {
/// The operands of an addition or subtraction have different units.
IncompatibleAddSub {
/// Units of the left-hand operand.
left: Units,
/// Units of the right-hand operand.
right: Units,
},
/// An operation that requires a dimensionless argument was given a dimensioned one.
NonDimensionlessArgument {
/// Name of the operation, as shown verbatim in the error message.
op: &'static str,
/// Units of the argument that was actually supplied.
got: Units,
},
/// A power whose base is dimensioned was given an exponent that is not an
/// integer constant (wording taken from this variant's `Display` message).
NonRationalPower {
/// Units of the base.
base_units: Units,
},
/// A variable index lies outside the per-variable units table.
VarIndexOutOfRange {
/// The offending index.
index: usize,
/// Number of entries in the units table (called `var_units` in the message).
n_vars: usize,
},
/// Scaling an exponent by an integer power left the `i8` range (returned by `Units::pow_int`).
ExponentOverflow {
/// Slot of the overflowing dimension in the `[m, kg, s, A, K, mol, cd]` vector.
dimension: usize,
/// The exponent that slot had before scaling.
base_exp: i8,
/// The power it was being raised to.
power: i32,
},
}
impl fmt::Display for UnitError {
    /// Renders a single-line, human-readable description of the unit error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` binds plain values.
        match *self {
            Self::IncompatibleAddSub {
                left: lhs,
                right: rhs,
            } => write!(f, "add/sub unit mismatch: {lhs} ≠ {rhs}"),
            Self::NonDimensionlessArgument {
                op: operation,
                got: actual,
            } => write!(
                f,
                "{operation} requires a dimensionless argument, got {actual}"
            ),
            Self::NonRationalPower { base_units: base } => write!(
                f,
                "Pow with dimensioned base ({base}) requires an integer-constant exponent"
            ),
            Self::VarIndexOutOfRange {
                index: idx,
                n_vars: count,
            } => write!(
                f,
                "variable index {idx} out of range (var_units has {count} entries)"
            ),
            Self::ExponentOverflow {
                dimension: dim,
                base_exp: exponent,
                power: n,
            } => write!(
                f,
                "exponent overflow for dimension {dim}: {exponent} × {n} exceeds i8"
            ),
        }
    }
}

/// `UnitError` wraps no lower-level error, so the trait's default `source()` (`None`) applies.
impl std::error::Error for UnitError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dimensionless_is_zero_vector() {
        assert!(Units::DIMENSIONLESS.is_dimensionless());
        assert_eq!(Units::DIMENSIONLESS.0, [0; 7]);
        assert!(!Units::METER.is_dimensionless());
    }

    #[test]
    fn named_units_correct() {
        // One base unit per slot, in [m, kg, s, A, K, mol, cd] order.
        let base = [
            Units::METER,
            Units::KILOGRAM,
            Units::SECOND,
            Units::AMPERE,
            Units::KELVIN,
            Units::MOL,
            Units::CANDELA,
        ];
        for (i, unit) in base.iter().enumerate() {
            let mut expected = [0i8; 7];
            expected[i] = 1;
            assert_eq!(unit.0, expected, "base unit in slot {i}");
        }
    }

    #[test]
    fn new_wraps_exponents_verbatim() {
        assert_eq!(Units::new([1, 1, -2, 0, 0, 0, 0]), Units::NEWTON);
    }

    #[test]
    fn mul_adds_exponents() {
        let result = Units::METER.mul(&Units::SECOND);
        assert_eq!(result.0, [1, 0, 1, 0, 0, 0, 0]);
    }

    #[test]
    fn div_subtracts_exponents() {
        let result = Units::METER.div(&Units::SECOND);
        assert_eq!(result.0, [1, 0, -1, 0, 0, 0, 0]);
    }

    #[test]
    fn div_undoes_mul() {
        let product = Units::NEWTON.mul(&Units::SECOND);
        assert_eq!(product.div(&Units::SECOND), Units::NEWTON);
    }

    #[test]
    fn mul_and_div_saturate_at_i8_bounds() {
        // Pins the current behaviour: out-of-range exponents clamp rather than wrap or panic.
        let max = Units::new([i8::MAX, 0, 0, 0, 0, 0, 0]);
        assert_eq!(max.mul(&Units::METER).0[0], i8::MAX);
        let min = Units::new([i8::MIN, 0, 0, 0, 0, 0, 0]);
        assert_eq!(min.div(&Units::METER).0[0], i8::MIN);
    }

    #[test]
    fn pow_int_scales_exponents() {
        assert_eq!(
            Units::METER.pow_int(3),
            Ok(Units::new([3, 0, 0, 0, 0, 0, 0]))
        );
        assert_eq!(Units::NEWTON.pow_int(0), Ok(Units::DIMENSIONLESS));
        assert_eq!(Units::HERTZ.pow_int(-1), Ok(Units::SECOND));
    }

    #[test]
    fn pow_int_accepts_exact_i8_bounds() {
        assert_eq!(Units::METER.pow_int(127).map(|u| u.0[0]), Ok(127));
        assert_eq!(Units::METER.pow_int(-128).map(|u| u.0[0]), Ok(-128));
    }

    #[test]
    fn pow_int_overflow_returns_err() {
        assert_eq!(
            Units::METER.pow_int(200),
            Err(UnitError::ExponentOverflow {
                dimension: 0,
                base_exp: 1,
                power: 200,
            })
        );
        // Negating i8::MIN does not fit back into an i8.
        assert_eq!(
            Units::new([i8::MIN, 0, 0, 0, 0, 0, 0]).pow_int(-1),
            Err(UnitError::ExponentOverflow {
                dimension: 0,
                base_exp: i8::MIN,
                power: -1,
            })
        );
        // m and kg scale to 100 (fits); s scales to -200 (does not) and is reported.
        assert_eq!(
            Units::NEWTON.pow_int(100),
            Err(UnitError::ExponentOverflow {
                dimension: 2,
                base_exp: -2,
                power: 100,
            })
        );
    }

    #[test]
    fn display_meter_per_second() {
        let v = Units::METER.div(&Units::SECOND);
        assert_eq!(v.to_string(), "m\u{00B7}s\u{207B}\u{00B9}");
    }

    #[test]
    fn display_dimensionless() {
        assert_eq!(Units::DIMENSIONLESS.to_string(), "1");
    }

    #[test]
    fn display_omits_first_power_and_zero_dimensions() {
        assert_eq!(Units::METER.to_string(), "m");
        assert_eq!(
            Units::NEWTON.to_string(),
            "m\u{00B7}kg\u{00B7}s\u{207B}\u{00B2}"
        );
        assert_eq!(
            Units::new([1; 7]).to_string(),
            "m\u{00B7}kg\u{00B7}s\u{00B7}A\u{00B7}K\u{00B7}mol\u{00B7}cd"
        );
    }

    #[test]
    fn display_multi_digit_and_extreme_exponents() {
        assert_eq!(Units::new([3, 0, 0, 0, 0, 0, 0]).to_string(), "m\u{00B3}");
        assert_eq!(
            Units::new([10, 0, 0, 0, 0, 0, i8::MIN]).to_string(),
            "m\u{00B9}\u{2070}\u{00B7}cd\u{207B}\u{00B9}\u{00B2}\u{2078}"
        );
    }

    #[test]
    fn derived_newton() {
        let newton = Units::KILOGRAM.mul(&Units::METER).mul(
            &Units::SECOND
                .pow_int(-2)
                .expect("pow_int(-2) does not overflow"),
        );
        assert_eq!(newton, Units::NEWTON);
    }

    #[test]
    fn derived_units_match_definitions() {
        let square_meter = Units::METER
            .pow_int(2)
            .expect("pow_int(2) does not overflow");
        assert_eq!(Units::NEWTON.mul(&Units::METER), Units::JOULE);
        assert_eq!(Units::JOULE.div(&Units::SECOND), Units::WATT);
        assert_eq!(Units::NEWTON.div(&square_meter), Units::PASCAL);
        assert_eq!(Units::DIMENSIONLESS.div(&Units::SECOND), Units::HERTZ);
        assert_eq!(Units::AMPERE.mul(&Units::SECOND), Units::COULOMB);
        assert_eq!(Units::WATT.div(&Units::AMPERE), Units::VOLT);
        assert_eq!(Units::VOLT.div(&Units::AMPERE), Units::OHM);
    }

    #[test]
    fn unit_error_display_messages() {
        let cases = [
            (
                UnitError::IncompatibleAddSub {
                    left: Units::METER,
                    right: Units::SECOND,
                },
                "add/sub unit mismatch: m ≠ s",
            ),
            (
                UnitError::NonDimensionlessArgument {
                    op: "exp",
                    got: Units::METER,
                },
                "exp requires a dimensionless argument, got m",
            ),
            (
                UnitError::NonRationalPower {
                    base_units: Units::METER,
                },
                "Pow with dimensioned base (m) requires an integer-constant exponent",
            ),
            (
                UnitError::VarIndexOutOfRange {
                    index: 3,
                    n_vars: 2,
                },
                "variable index 3 out of range (var_units has 2 entries)",
            ),
            (
                UnitError::ExponentOverflow {
                    dimension: 0,
                    base_exp: 1,
                    power: 200,
                },
                "exponent overflow for dimension 0: 1 × 200 exceeds i8",
            ),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[test]
    fn unit_error_has_no_source() {
        let err = UnitError::VarIndexOutOfRange {
            index: 0,
            n_vars: 0,
        };
        assert!(std::error::Error::source(&err).is_none());
    }
}